From fa500606eb8d4c703ef2ba1d06ae446252d001d8 Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Fri, 11 Jul 2025 19:04:51 +0100 Subject: [PATCH 01/11] ideation around rate limiting logic --- Cargo.lock | 21 ++++++++++ lib/llm/Cargo.toml | 3 ++ lib/llm/src/http/service.rs | 1 + lib/llm/src/http/service/openai.rs | 53 ++++++++++++++++++++++++++ lib/llm/src/http/service/service_v2.rs | 22 ++++++++++- 5 files changed, 98 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0782b14612..7250793a26 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -536,6 +536,12 @@ version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + [[package]] name = "base64" version = "0.22.1" @@ -1798,6 +1804,7 @@ dependencies = [ "futures", "galil-seiferas", "ggus", + "hdrhistogram", "hf-hub", "humantime", "insta", @@ -2861,6 +2868,20 @@ dependencies = [ "foldhash", ] +[[package]] +name = "hdrhistogram" +version = "7.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "765c9198f173dd59ce26ff9f95ef0aafd0a0fe01fb9d72841bc5066a4c06511d" +dependencies = [ + "base64 0.21.7", + "byteorder", + "crossbeam-channel", + "flate2", + "nom", + "num-traits", +] + [[package]] name = "heck" version = "0.4.1" diff --git a/lib/llm/Cargo.toml b/lib/llm/Cargo.toml index 62403dd029..988a0d21d7 100644 --- a/lib/llm/Cargo.toml +++ b/lib/llm/Cargo.toml @@ -108,6 +108,9 @@ tokenizers = { version = "0.21.1", default-features = false, features = [ ] } sentencepiece = { version = "0.11.2", optional = true } +# metrics +hdrhistogram = "7.5.4" + # backend galil-seiferas = { version = "0.1" } toktrie = { version = "0.6.28" } diff --git a/lib/llm/src/http/service.rs b/lib/llm/src/http/service.rs index 9c4081f6ef..6828132966 100644 --- a/lib/llm/src/http/service.rs +++ b/lib/llm/src/http/service.rs @@ -23,6 +23,7 @@ mod openai; pub mod error; pub mod health; pub mod metrics; +pub mod rate_limiter; pub mod service_v2; pub use axum; diff --git a/lib/llm/src/http/service/openai.rs b/lib/llm/src/http/service/openai.rs index 025b79024f..41b1d70b19 100644 --- a/lib/llm/src/http/service/openai.rs +++ b/lib/llm/src/http/service/openai.rs @@ -11,10 +11,12 @@ use std::{ use axum::{ extract::State, http::StatusCode, + middleware::Next, response::{ sse::{Event, KeepAlive, Sse}, IntoResponse, Response, }, + http::Request, routing::{get, post}, Json, Router, }; @@ -92,6 +94,28 @@ impl ErrorResponse { ) } + /// Rate Limit Exceeded + /// Return this error when the request is rejected due to rate limiting. + pub fn rate_limit_exceeded(msg: &str) -> (StatusCode, Json) { + ( + StatusCode::TOO_MANY_REQUESTS, + Json(ErrorResponse { + error: msg.to_string(), + }), + ) + } + + /// Bad Request + /// Return this error when the received request is malformed. + pub fn bad_request(msg: &str) -> (StatusCode, Json) { + ( + StatusCode::BAD_REQUEST, + Json(ErrorResponse { + error: msg.to_string(), + }), + ) + } + /// The OAI endpoints call an [`dynamo.runtime::engine::AsyncEngine`] which are specialized to return /// an [`anyhow::Error`]. This method will convert the [`anyhow::Error`] into an [`HttpError`]. 
/// If successful, it will return the [`HttpError`] as an [`ErrorResponse::internal_server_error`] @@ -121,6 +145,35 @@ impl From for ErrorResponse { } } +/// Rate Limit Middleware +/// +/// This middleware will check if the current request should be rejected based on the rate limiter logic. +/// The rate limiter logic, on the other hand, keeps track of specific model metrics, such as TTFT and ITL. +/// If these values exceed specific configured thresholds, the request will be rejected. This is so that +/// we can optimize for goodput, and not necessarily raw throughput, across the service. +/// +/// If the request should be rejected, it will return a Status Code of 429 Too Many Requests. +/// Otherwise, it will call the next middleware/route handler. +pub async fn rate_limit_middleware( + State(state): State, + request: Request, + next: Next, +) -> Result)> { + let request_body = request.body(); + let model = request_body.inner.model.clone(); + let should_reject = state.rate_limiter().should_reject(&model).map_err(|_| { + ErrorResponse::internal_server_error("Failed to check rate limit") + })?; + + if should_reject { + return Err(ErrorResponse::rate_limit_exceeded(&format!( + "Rate limit exceeded for current request and model: {model}. Please retry later." + ))); + } + + Ok(next.run(request).await) +} + /// OpenAI Completions Request Handler /// /// This method will handle the incoming request for the `/v1/completions endpoint`. The endpoint is a "source" diff --git a/lib/llm/src/http/service/service_v2.rs b/lib/llm/src/http/service/service_v2.rs index 71524db363..09ef548105 100644 --- a/lib/llm/src/http/service/service_v2.rs +++ b/lib/llm/src/http/service/service_v2.rs @@ -8,6 +8,8 @@ use super::metrics; use super::Metrics; use super::RouteDoc; use crate::discovery::ModelManager; +use crate::http::service::rate_limiter::RateLimiter; +use crate::http::service::rate_limiter::RateLimiterConfig; use crate::request_template::RequestTemplate; use anyhow::Result; use derive_builder::Builder; @@ -18,13 +20,15 @@ use tokio_util::sync::CancellationToken; pub struct State { metrics: Arc, manager: Arc, + rate_limiter: Arc, } impl State { - pub fn new(manager: Arc) -> Self { + pub fn new(manager: Arc, rate_limiter: Arc) -> Self { Self { manager, metrics: Arc::new(Metrics::default()), + rate_limiter, } } @@ -41,6 +45,14 @@ impl State { self.manager.clone() } + pub fn rate_limiter(&self) -> &RateLimiter { + Arc::as_ref(&self.rate_limiter) + } + + pub fn rate_limiter_clone(&self) -> Arc { + self.rate_limiter.clone() + } + // TODO pub fn sse_keep_alive(&self) -> Option { None @@ -83,6 +95,9 @@ pub struct HttpServiceConfig { #[builder(default = "None")] request_template: Option, + + #[builder(default = "None")] + rate_limiter_config: Option, } impl HttpService { @@ -137,7 +152,10 @@ impl HttpServiceConfigBuilder { let config: HttpServiceConfig = self.build_internal()?; let model_manager = Arc::new(ModelManager::new()); - let state = Arc::new(State::new(model_manager)); + let rate_limiter = Arc::new(RateLimiter::new( + config.rate_limiter_config.unwrap_or_default(), + )); + let state = Arc::new(State::new(model_manager, rate_limiter)); // enable prometheus metrics let registry = metrics::Registry::new(); From c7af6331db767756c63da768a7da41c8bf326ba7 Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Fri, 11 Jul 2025 19:09:25 +0100 Subject: [PATCH 02/11] rate limit logic for percentile real time metrics extraction --- lib/llm/src/http/service/openai.rs | 53 ++--- lib/llm/src/http/service/rate_limiter.rs | 
248 +++++++++++++++++++++++ 2 files changed, 270 insertions(+), 31 deletions(-) create mode 100644 lib/llm/src/http/service/rate_limiter.rs diff --git a/lib/llm/src/http/service/openai.rs b/lib/llm/src/http/service/openai.rs index 41b1d70b19..f73801543e 100644 --- a/lib/llm/src/http/service/openai.rs +++ b/lib/llm/src/http/service/openai.rs @@ -11,12 +11,10 @@ use std::{ use axum::{ extract::State, http::StatusCode, - middleware::Next, response::{ sse::{Event, KeepAlive, Sse}, IntoResponse, Response, }, - http::Request, routing::{get, post}, Json, Router, }; @@ -145,35 +143,6 @@ impl From for ErrorResponse { } } -/// Rate Limit Middleware -/// -/// This middleware will check if the current request should be rejected based on the rate limiter logic. -/// The rate limiter logic, on the other hand, keeps track of specific model metrics, such as TTFT and ITL. -/// If these values exceed specific configured thresholds, the request will be rejected. This is so that -/// we can optimize for goodput, and not necessarily raw throughput, across the service. -/// -/// If the request should be rejected, it will return a Status Code of 429 Too Many Requests. -/// Otherwise, it will call the next middleware/route handler. -pub async fn rate_limit_middleware( - State(state): State, - request: Request, - next: Next, -) -> Result)> { - let request_body = request.body(); - let model = request_body.inner.model.clone(); - let should_reject = state.rate_limiter().should_reject(&model).map_err(|_| { - ErrorResponse::internal_server_error("Failed to check rate limit") - })?; - - if should_reject { - return Err(ErrorResponse::rate_limit_exceeded(&format!( - "Rate limit exceeded for current request and model: {model}. Please retry later." - ))); - } - - Ok(next.run(request).await) -} - /// OpenAI Completions Request Handler /// /// This method will handle the incoming request for the `/v1/completions endpoint`. The endpoint is a "source" @@ -211,6 +180,17 @@ async fn completions( // todo - when optional, if none, apply a default let model = &request.inner.model; + // Rate limit check + let should_reject = state + .rate_limiter() + .should_reject(model) + .map_err(|_| ErrorResponse::internal_server_error("Failed to check rate limit"))?; + if should_reject { + return Err(ErrorResponse::rate_limit_exceeded(&format!( + "Rate limit exceeded for current request and model: {model}. Please retry later." + ))); + } + // todo - error handling should be more robust let engine = state .manager() @@ -387,6 +367,17 @@ async fn chat_completions( // todo - when optional, if none, apply a default let model = &request.inner.model; + // Rate limit check + let should_reject = state + .rate_limiter() + .should_reject(model) + .map_err(|_| ErrorResponse::internal_server_error("Failed to check rate limit"))?; + if should_reject { + return Err(ErrorResponse::rate_limit_exceeded(&format!( + "Rate limit exceeded for current request and model: {model}. Please retry later." 
+ ))); + } + // todo - determine the proper error code for when a request model is not present tracing::trace!("Getting chat completions engine for model: {}", model); diff --git a/lib/llm/src/http/service/rate_limiter.rs b/lib/llm/src/http/service/rate_limiter.rs new file mode 100644 index 0000000000..0f038abf71 --- /dev/null +++ b/lib/llm/src/http/service/rate_limiter.rs @@ -0,0 +1,248 @@ +use std::{ + collections::HashMap, + sync::{Arc, RwLock}, + time::{Duration, Instant}, +}; + +use anyhow::{Context, Result}; +use hdrhistogram::Histogram; + +const MODEL_METRICS_PRECISION: u8 = 6; + +#[derive(Debug, Clone)] +pub struct RateLimiterConfig { + enabled: bool, + ttft_threshold_ms: f64, + ttft_percentile: f64, + itl_threshold_ms: f64, + itl_percentile: f64, + window_duration: Duration, + per_model_limits: bool, +} + +impl Default for RateLimiterConfig { + fn default() -> Self { + Self { + enabled: true, + ttft_threshold_ms: 1000.0, // 1s + itl_threshold_ms: 10.0, // 10ms + ttft_percentile: 0.95, // 95th percentile + itl_percentile: 0.95, // 95th percentile + window_duration: Duration::from_secs(5), + per_model_limits: false, + } + } +} + +#[derive(Debug)] +pub struct WindowedHistogram { + current: Histogram, + previous: Option>, + window_start: Instant, + window_duration: Duration, + /// Max value in microseconds + max_value: u64, + precision: u8, +} + +impl WindowedHistogram { + pub fn new(window_duration: Duration, max_value: u64, precision: u8) -> Result { + let histogram = Histogram::::new_with_bounds(1, max_value, precision) + .context("Failed to create histogram")?; + + Ok(Self { + current: histogram, + previous: None, + window_start: Instant::now(), + window_duration, + max_value, + precision, + }) + } + + fn record_value(&mut self, value: f64) -> Result<()> { + self.maybe_rotate_window()?; + + // Convert to microseconds for better precision + let value_us = (value * 1000.0).round() as u64; + let clamped_value = value_us.min(self.max_value).max(1); + self.current + .record(clamped_value) + .context("Failed to record value to histogram")?; + + Ok(()) + } + + fn get_percentile(&self, percentile: f64) -> Result { + let sample_count = self.sample_count(); + + if sample_count == 0 { + return Ok(0.0); + } + + let percentile_us = self.current.value_at_percentile(percentile) as f64; + + // Convert back from microseconds to milliseconds + Ok(percentile_us / 1000.0) + } + + fn sample_count(&self) -> u64 { + self.current.len() + self.previous.as_ref().map_or(0, |h| h.len()) + } + + fn maybe_rotate_window(&mut self) -> Result<()> { + let now = Instant::now(); + if now.duration_since(self.window_start) > self.window_duration { + let new_histogram = + Histogram::::new_with_bounds(1, self.max_value, self.precision) + .context("Failed to create new histogram")?; + + self.previous = Some(std::mem::replace(&mut self.current, new_histogram)); + self.window_start = now; + } + + Ok(()) + } +} + +#[derive(Debug)] +struct ModelMetrics { + ttft_histogram: WindowedHistogram, + itl_histogram: WindowedHistogram, +} + +impl ModelMetrics { + fn new(config: &RateLimiterConfig) -> Result { + let ttft_histogram = WindowedHistogram::new( + config.window_duration, + (config.ttft_threshold_ms * 1000.0).round() as u64, // Convert to microseconds + MODEL_METRICS_PRECISION, + )?; + let itl_histogram = WindowedHistogram::new( + config.window_duration, + (config.itl_threshold_ms * 1000.0).round() as u64, // Convert to microseconds + MODEL_METRICS_PRECISION, + )?; + + Ok(Self { + ttft_histogram, + itl_histogram, + }) + } +} 
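// A sketch for orientation, not part of the patch: a minimal test-style exercise of the
// two-bucket window above, assuming it lives in a `#[cfg(test)]` module inside this file
// so it can reach the private `record_value`, `get_percentile`, and `sample_count`
// helpers. Samples survive for at most two windows (the rotated `previous` bucket plus
// `current`), and hdrhistogram's `value_at_percentile` takes a percentile on the 0-100 scale.
#[cfg(test)]
mod windowed_histogram_sketch {
    use super::*;
    use std::thread;

    #[test]
    fn rotation_keeps_at_most_two_windows() -> Result<()> {
        // 50 ms window, values capped at 1 s (stored as microseconds), 3 significant digits.
        let mut hist = WindowedHistogram::new(Duration::from_millis(50), 1_000_000, 3)?;

        // TTFT-like samples in milliseconds; record_value converts them to microseconds.
        for &ttft_ms in &[120.0, 180.0, 240.0] {
            hist.record_value(ttft_ms)?;
        }
        assert_eq!(hist.sample_count(), 3);

        // Once a window elapses, the next record rotates `current` into `previous`,
        // so the earlier samples still contribute to sample_count for one more window.
        thread::sleep(Duration::from_millis(60));
        hist.record_value(300.0)?;
        assert_eq!(hist.sample_count(), 4);

        let p95_ms = hist.get_percentile(95.0)?;
        assert!(p95_ms > 0.0);
        Ok(())
    }
}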
+ +pub struct RateLimiter { + config: RateLimiterConfig, + // TODO: Can make this a `DashMap` to avoid the need to lock the entire map + model_metrics: Arc>>, +} + +impl RateLimiter { + pub fn new(config: RateLimiterConfig) -> Self { + Self { + config, + model_metrics: Arc::new(RwLock::new(HashMap::new())), + } + } + + #[inline] + fn get_model_key(&self, model: &str) -> String { + if self.config.per_model_limits { + model.to_string() + } else { + "global".to_string() + } + } + + /// Record the time to first token metric for a given model + pub fn record_ttft(&self, model: &str, ttft_ms: f64) -> Result<()> { + if !self.config.enabled { + return Ok(()); + } + + let model_key = self.get_model_key(model); + let mut metrics = self.model_metrics.write().unwrap(); + let model_metrics = metrics + .entry(model_key) + .or_insert_with(|| ModelMetrics::new(&self.config).unwrap()); + + model_metrics + .ttft_histogram + .record_value(ttft_ms) + .context("Failed to record time to first token metric") + } + + /// Record the inter-token latency metric for a given model + pub fn record_itl(&self, model: &str, itl_ms: f64) -> Result<()> { + if !self.config.enabled { + return Ok(()); + } + + let model_key = self.get_model_key(model); + let mut metrics = self.model_metrics.write().unwrap(); + let model_metrics = metrics + .entry(model_key) + .or_insert_with(|| ModelMetrics::new(&self.config).unwrap()); + + model_metrics + .itl_histogram + .record_value(itl_ms) + .context("Failed to record inter-token latency metric") + } + + /// Check if the request should be rejected based on the cached metrics + /// + /// Returns true if the request should be rejected, false otherwise + pub fn should_reject(&self, model: &str) -> Result { + if !self.config.enabled { + return Ok(false); + } + + let model_key = self.get_model_key(model); + let metrics = self.model_metrics.write().unwrap(); + + let Some(model_metrics) = metrics.get(&model_key) else { + return Ok(false); + }; + + let ttft_percentile_ms = model_metrics + .ttft_histogram + .get_percentile(self.config.ttft_percentile)?; + let itl_percentile_ms = model_metrics + .itl_histogram + .get_percentile(self.config.itl_percentile)?; + + let ttft_samples = model_metrics.ttft_histogram.sample_count(); + let itl_samples = model_metrics.itl_histogram.sample_count(); + + // Don't reject if we don't have enough samples + if ttft_samples == 0 || itl_samples == 0 { + return Ok(false); + } + + let ttft_exceeded = self.config.ttft_threshold_ms <= ttft_percentile_ms; + let itl_exceeded = self.config.itl_threshold_ms <= itl_percentile_ms; + + if ttft_exceeded || itl_exceeded { + tracing::warn!( + model = model, + ttft_threshold_ms = self.config.ttft_threshold_ms, + itl_threshold_ms = self.config.itl_threshold_ms, + "Rate limit exceeded for model {model}: ttft: {ttft_percentile_ms}ms, itl: {itl_percentile_ms}ms", + ttft_percentile_ms = ttft_percentile_ms, + itl_percentile_ms = itl_percentile_ms, + ); + return Ok(true); + } + + Ok(false) + } + + pub fn clear_model_metrics(&self, model: &str) -> Result<()> { + let model_key = self.get_model_key(model); + let mut metrics = self.model_metrics.write().unwrap(); + metrics.remove(&model_key); + + Ok(()) + } +} From 2eed9b219ecba5d6f456668179578d067fc78ed2 Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Tue, 15 Jul 2025 15:23:16 +0100 Subject: [PATCH 03/11] add further rate limiter logic for weighted moving averages --- Cargo.lock | 38 +- lib/llm/Cargo.toml | 8 +- lib/llm/benches/rate_limiter.rs | 341 +++++++ 
lib/llm/src/http/service/metrics.rs | 72 +- lib/llm/src/http/service/openai.rs | 70 +- lib/llm/src/http/service/rate_limiter.rs | 1192 +++++++++++++++++++--- lib/llm/src/http/service/service_v2.rs | 4 +- 7 files changed, 1502 insertions(+), 223 deletions(-) create mode 100644 lib/llm/benches/rate_limiter.rs diff --git a/Cargo.lock b/Cargo.lock index 7250793a26..a87133aeb4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -536,12 +536,6 @@ version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" -[[package]] -name = "base64" -version = "0.21.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" - [[package]] name = "base64" version = "0.22.1" @@ -1453,6 +1447,20 @@ dependencies = [ "parking_lot_core", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if 1.0.0", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.9.0" @@ -1794,6 +1802,7 @@ dependencies = [ "chrono", "criterion", "cudarc 0.16.2", + "dashmap 6.1.0", "derive-getters", "derive_builder", "dialoguer", @@ -1804,7 +1813,6 @@ dependencies = [ "futures", "galil-seiferas", "ggus", - "hdrhistogram", "hf-hub", "humantime", "insta", @@ -2868,20 +2876,6 @@ dependencies = [ "foldhash", ] -[[package]] -name = "hdrhistogram" -version = "7.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "765c9198f173dd59ce26ff9f95ef0aafd0a0fe01fb9d72841bc5066a4c06511d" -dependencies = [ - "base64 0.21.7", - "byteorder", - "crossbeam-channel", - "flate2", - "nom", - "num-traits", -] - [[package]] name = "heck" version = "0.4.1" @@ -8504,7 +8498,7 @@ dependencies = [ "asynchronous-codec", "bytes", "crossbeam-queue", - "dashmap", + "dashmap 5.5.3", "futures-channel", "futures-io", "futures-task", diff --git a/lib/llm/Cargo.toml b/lib/llm/Cargo.toml index 988a0d21d7..1fccf5ff05 100644 --- a/lib/llm/Cargo.toml +++ b/lib/llm/Cargo.toml @@ -36,6 +36,10 @@ testing-nixl = ["dep:nixl-sys"] block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:ndarray", "dep:nix"] sentencepiece = ["dep:sentencepiece"] +[[bench]] +name = "rate_limiter" +harness = false + [[bench]] name = "tokenizer" harness = false @@ -108,10 +112,8 @@ tokenizers = { version = "0.21.1", default-features = false, features = [ ] } sentencepiece = { version = "0.11.2", optional = true } -# metrics -hdrhistogram = "7.5.4" - # backend +dashmap = { version = "6.1.0" } galil-seiferas = { version = "0.1" } toktrie = { version = "0.6.28" } toktrie_hf_tokenizers = { version = "0.6.28" } diff --git a/lib/llm/benches/rate_limiter.rs b/lib/llm/benches/rate_limiter.rs new file mode 100644 index 0000000000..0d6bacd7d3 --- /dev/null +++ b/lib/llm/benches/rate_limiter.rs @@ -0,0 +1,341 @@ +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use dynamo_llm::http::service::rate_limiter::{ + RateLimiter, RateLimiterConfig, TimeWeightedAverageTracker, +}; +use std::sync::Arc; +use std::thread; +use std::time::Duration; + +// Benchmark configurations +const SAMPLE_SIZES: &[usize] = &[10, 100, 1000, 10000]; +const TIME_CONSTANTS: &[f64] = &[1.0, 10.0, 30.0, 60.0]; +const THREAD_COUNTS: &[usize] = 
&[1, 2, 4, 8, 16]; + +/// Benchmark recording single values to the tracker +fn bench_record_value(c: &mut Criterion) { + let mut group = c.benchmark_group("record_value"); + + for &sample_size in SAMPLE_SIZES { + group.throughput(Throughput::Elements(sample_size as u64)); + + group.bench_with_input( + BenchmarkId::new("sequential", sample_size), + &sample_size, + |b, &size| { + b.iter(|| { + let mut tracker = TimeWeightedAverageTracker::new(10.0); + for i in 0..size { + tracker.record_value(black_box(i as f64)); + } + }); + }, + ); + } + group.finish(); +} + +/// Benchmark computing time-weighted averages +fn bench_time_weighted_average(c: &mut Criterion) { + let mut group = c.benchmark_group("time_weighted_average"); + + for &sample_size in SAMPLE_SIZES { + group.throughput(Throughput::Elements(1)); // One calculation per iteration + + group.bench_with_input( + BenchmarkId::new("computation", sample_size), + &sample_size, + |b, &size| { + // Pre-populate tracker with samples + let mut tracker = TimeWeightedAverageTracker::new(10.0); + for i in 0..size { + tracker.record_value(i as f64); + if i % 100 == 0 { + // Add some time variance + thread::sleep(Duration::from_nanos(1)); + } + } + + b.iter(|| { + black_box(tracker.get_decayed_time_weighted_average()); + }); + }, + ); + } + group.finish(); +} + +/// Benchmark different time constants impact on performance +fn bench_time_constants(c: &mut Criterion) { + let mut group = c.benchmark_group("time_constants"); + + const SAMPLE_SIZE: usize = 1000; + + for &time_constant in TIME_CONSTANTS { + group.bench_with_input( + BenchmarkId::new("record_and_compute", time_constant), + &time_constant, + |b, &tc| { + b.iter(|| { + let mut tracker = TimeWeightedAverageTracker::new(tc); + + // Record samples + for i in 0..SAMPLE_SIZE { + tracker.record_value(black_box(i as f64)); + } + + // Compute average + black_box(tracker.get_decayed_time_weighted_average()); + }); + }, + ); + } + group.finish(); +} + +/// Benchmark rate limiter decision making +fn bench_rate_limiter_decisions(c: &mut Criterion) { + let mut group = c.benchmark_group("rate_limiter_decisions"); + + let config = RateLimiterConfig::new(100.0, 10.0, 10.0, false); + + group.bench_function("should_reject_with_data", |b| { + let rate_limiter = RateLimiter::new(Some(config.clone())); + + // Pre-populate with samples + for i in 0..100 { + rate_limiter.record_ttft("test-model", 50.0 + i as f64); + rate_limiter.record_itl("test-model", 5.0 + (i as f64 / 10.0)); + } + + b.iter(|| { + black_box(rate_limiter.should_reject(black_box("test-model"))); + }); + }); + + group.bench_function("record_ttft", |b| { + let rate_limiter = RateLimiter::new(Some(config.clone())); + let mut counter = 0; + + b.iter(|| { + rate_limiter.record_ttft(black_box("test-model"), black_box(counter as f64)); + counter += 1; + }); + }); + + group.bench_function("record_itl", |b| { + let rate_limiter = RateLimiter::new(Some(config.clone())); + let mut counter = 0; + + b.iter(|| { + rate_limiter.record_itl(black_box("test-model"), black_box(counter as f64)); + counter += 1; + }); + }); + + group.finish(); +} + +/// Benchmark concurrent access patterns +fn bench_concurrent_access(c: &mut Criterion) { + let mut group = c.benchmark_group("concurrent_access"); + + for &thread_count in THREAD_COUNTS { + group.throughput(Throughput::Elements(thread_count as u64 * 100)); + + group.bench_with_input( + BenchmarkId::new("multi_thread_records", thread_count), + &thread_count, + |b, &num_threads| { + b.iter(|| { + let config = 
RateLimiterConfig::new(1000.0, 10.0, 30.0, false); + let rate_limiter = Arc::new(RateLimiter::new(Some(config))); + + let handles: Vec<_> = (0..num_threads) + .map(|thread_id| { + let limiter = rate_limiter.clone(); + thread::spawn(move || { + for i in 0..100 { + let value = (thread_id * 100 + i) as f64; + limiter.record_ttft("test-model", value); + limiter.record_itl("test-model", value / 10.0); + + // Some threads check rejection status + if i % 10 == 0 { + black_box(limiter.should_reject("test-model")); + } + } + }) + }) + .collect(); + + for handle in handles { + handle.join().unwrap(); + } + + black_box(rate_limiter); + }); + }, + ); + } + group.finish(); +} + +/// Benchmark memory allocation patterns +fn bench_memory_patterns(c: &mut Criterion) { + let mut group = c.benchmark_group("memory_patterns"); + + group.bench_function("memory_bounded_growth", |b| { + b.iter(|| { + let mut tracker = TimeWeightedAverageTracker::new(10.0); + + // Add way more samples than max_samples to test memory bounds + for i in 0..1000 { + tracker.record_value(black_box(i as f64)); + + // Occasionally compute average to trigger cleanup + if i % 50 == 0 { + black_box(tracker.get_decayed_time_weighted_average()); + } + } + + black_box(tracker); + }); + }); + + group.bench_function("per_model_isolation", |b| { + let config = RateLimiterConfig::new(1000.0, 10.0, 30.0, true); + + b.iter(|| { + let rate_limiter = RateLimiter::new(Some(config.clone())); + + // Simulate multiple models + for model_id in 0..10 { + let model_name = format!("model-{}", model_id); + for i in 0..50 { + rate_limiter.record_ttft(&model_name, i as f64); + rate_limiter.record_itl(&model_name, (i as f64) / 10.0); + } + black_box(rate_limiter.should_reject(&model_name)); + } + + black_box(rate_limiter); + }); + }); + + group.finish(); +} + +/// Benchmark edge cases and stress scenarios +fn bench_edge_cases(c: &mut Criterion) { + let mut group = c.benchmark_group("edge_cases"); + + group.bench_function("rapid_fire_records", |b| { + b.iter(|| { + let mut tracker = TimeWeightedAverageTracker::new(1.0); + + // Rapid fire recording without any delays + for i in 0..5000 { + tracker.record_value(black_box(i as f64)); + } + + black_box(tracker.get_decayed_time_weighted_average()); + }); + }); + + group.bench_function("alternating_high_low_values", |b| { + b.iter(|| { + let mut tracker = TimeWeightedAverageTracker::new(5.0); + + // Alternating between very high and very low values + for i in 0..500 { + let value = if i % 2 == 0 { 1000000.0 } else { 0.001 }; + tracker.record_value(black_box(value)); + } + + black_box(tracker.get_decayed_time_weighted_average()); + }); + }); + + group.bench_function("very_old_samples", |b| { + b.iter(|| { + let mut tracker = TimeWeightedAverageTracker::new(0.1); // Very short time constant + + // Add some samples + for i in 0..100 { + tracker.record_value(black_box(i as f64)); + } + + // Sleep to make them very old + thread::sleep(Duration::from_millis(100)); + + // Add fresh samples + for i in 100..200 { + tracker.record_value(black_box(i as f64)); + } + + black_box(tracker.get_decayed_time_weighted_average()); + }); + }); + + group.finish(); +} + +/// Comprehensive benchmark comparing different configurations +fn bench_configuration_comparison(c: &mut Criterion) { + let mut group = c.benchmark_group("configuration_comparison"); + + let configs = vec![ + ( + "aggressive", + RateLimiterConfig::new(1000.0, 10.0, 1.0, false), + ), + ( + "balanced", + RateLimiterConfig::new(1000.0, 10.0, 10.0, false), + ), + ( + 
"conservative", + RateLimiterConfig::new(1000.0, 10.0, 60.0, false), + ), + ]; + + for (name, config) in configs { + group.bench_with_input( + BenchmarkId::new("full_workflow", name), + &config, + |b, config| { + b.iter(|| { + let rate_limiter = RateLimiter::new(Some(config.clone())); + + // Simulate realistic usage pattern + for i in 0..200 { + rate_limiter.record_ttft("model", black_box(50.0 + (i as f64))); + rate_limiter.record_itl("model", black_box(5.0 + (i as f64 / 10.0))); + + if i % 20 == 0 { + black_box(rate_limiter.should_reject("model")); + } + } + + black_box(rate_limiter); + }); + }, + ); + } + + group.finish(); +} + +criterion_group!( + benches, + bench_record_value, + bench_time_weighted_average, + bench_time_constants, + bench_rate_limiter_decisions, + bench_concurrent_access, + bench_memory_patterns, + bench_edge_cases, + bench_configuration_comparison +); + +criterion_main!(benches); diff --git a/lib/llm/src/http/service/metrics.rs b/lib/llm/src/http/service/metrics.rs index 181763895c..212364b042 100644 --- a/lib/llm/src/http/service/metrics.rs +++ b/lib/llm/src/http/service/metrics.rs @@ -10,6 +10,8 @@ use std::{ pub use prometheus::Registry; +use crate::http::service::rate_limiter::RateLimiter; + use super::RouteDoc; /// Value for the `status` label in the request counter for successful requests @@ -25,6 +27,7 @@ pub const REQUEST_TYPE_STREAM: &str = "stream"; pub const REQUEST_TYPE_UNARY: &str = "unary"; pub struct Metrics { + rate_limit_requests_counter: IntCounterVec, request_counter: IntCounterVec, inflight_gauge: IntGaugeVec, request_duration: HistogramVec, @@ -81,6 +84,7 @@ pub enum Status { /// Track response-specific metrics pub struct ResponseMetricCollector { metrics: Arc, + rate_limiter: Arc, model: String, start_time: Instant, // we use is_first_token to distinguish TTFT from ITL. 
It is true by default and @@ -109,6 +113,15 @@ impl Metrics { /// - `{prefix}_http_service_time_to_first_token_seconds` - HistogramVec for time to first token in seconds /// - `{prefix}_http_service_inter_token_latency_seconds` - HistogramVec for inter-token latency in seconds pub fn new(prefix: &str) -> Self { + let rate_limit_requests_counter = IntCounterVec::new( + Opts::new( + format!("{}_http_service_rate_limit_requests_total", prefix), + "Total number of requests rejected by the rate limiter", + ), + &["model", "endpoint", "request_type", "status"], + ) + .unwrap(); + let request_counter = IntCounterVec::new( Opts::new( format!("{}_http_service_requests_total", prefix), @@ -190,6 +203,7 @@ impl Metrics { .unwrap(); Metrics { + rate_limit_requests_counter, request_counter, inflight_gauge, request_duration, @@ -200,6 +214,28 @@ impl Metrics { } } + /// Get the number of requests rejected by the rate limiter for the given dimensions: + /// - model + /// - endpoint (completions/chat_completions) + /// - request type (unary/stream) + /// - status (success/error) + pub fn get_rate_limit_requests_counter( + &self, + model: &str, + endpoint: &Endpoint, + request_type: &RequestType, + status: &Status, + ) -> u64 { + self.rate_limit_requests_counter + .with_label_values(&[ + model, + endpoint.as_str(), + request_type.as_str(), + status.as_str(), + ]) + .get() + } + /// Get the number of successful requests for the given dimensions: /// - model /// - endpoint (completions/chat_completions) @@ -222,6 +258,28 @@ impl Metrics { .get() } + /// Increment the counter for requests rejected by the rate limiter for the given dimensions: + /// - model + /// - endpoint (completions/chat_completions) + /// - request type (unary/stream) + /// - status (success/error) + pub fn inc_rate_limit_requests_counter( + &self, + model: &str, + endpoint: &Endpoint, + request_type: &RequestType, + status: &Status, + ) { + self.rate_limit_requests_counter + .with_label_values(&[ + model, + endpoint.as_str(), + request_type.as_str(), + status.as_str(), + ]) + .inc() + } + /// Increment the counter for requests for the given dimensions: /// - model /// - endpoint (completions/chat_completions) @@ -258,6 +316,7 @@ impl Metrics { } pub fn register(&self, registry: &Registry) -> Result<(), prometheus::Error> { + registry.register(Box::new(self.rate_limit_requests_counter.clone()))?; registry.register(Box::new(self.request_counter.clone()))?; registry.register(Box::new(self.inflight_gauge.clone()))?; registry.register(Box::new(self.request_duration.clone()))?; @@ -294,8 +353,12 @@ impl Metrics { } /// Create a new [`ResponseMetricCollector`] for collecting per-response metrics (i.e., TTFT, ITL) - pub fn create_response_collector(self: Arc, model: &str) -> ResponseMetricCollector { - ResponseMetricCollector::new(self, model.to_string().to_lowercase()) + pub fn create_response_collector( + self: Arc, + model: &str, + rate_limiter: Arc, + ) -> ResponseMetricCollector { + ResponseMetricCollector::new(self, model.to_string().to_lowercase(), rate_limiter) } } @@ -392,9 +455,10 @@ impl Status { } impl ResponseMetricCollector { - fn new(metrics: Arc, model: String) -> Self { + fn new(metrics: Arc, model: String, rate_limiter: Arc) -> Self { ResponseMetricCollector { metrics, + rate_limiter, model, is_first_token: true, last_response_time: None, @@ -425,6 +489,7 @@ impl ResponseMetricCollector { .time_to_first_token .with_label_values(&[&self.model]) .observe(ttft); + self.rate_limiter.record_ttft(&self.model, ttft); // Publish ISL // 
TODO: publish ISL as soon as the tokenization process completes @@ -445,6 +510,7 @@ impl ResponseMetricCollector { .with_label_values(&[&self.model]) .observe(itl); } + self.rate_limiter.record_itl(&self.model, itl); } self.last_response_time = Some(current_duration); diff --git a/lib/llm/src/http/service/openai.rs b/lib/llm/src/http/service/openai.rs index f73801543e..48849688e9 100644 --- a/lib/llm/src/http/service/openai.rs +++ b/lib/llm/src/http/service/openai.rs @@ -103,17 +103,6 @@ impl ErrorResponse { ) } - /// Bad Request - /// Return this error when the received request is malformed. - pub fn bad_request(msg: &str) -> (StatusCode, Json) { - ( - StatusCode::BAD_REQUEST, - Json(ErrorResponse { - error: msg.to_string(), - }), - ) - } - /// The OAI endpoints call an [`dynamo.runtime::engine::AsyncEngine`] which are specialized to return /// an [`anyhow::Error`]. This method will convert the [`anyhow::Error`] into an [`HttpError`]. /// If successful, it will return the [`HttpError`] as an [`ErrorResponse::internal_server_error`] @@ -159,6 +148,9 @@ async fn completions( // return a 503 if the service is not ready check_ready(&state)?; + // Rate limit check + should_reject_request(&state, &request.inner.model)?; + // todo - extract distributed tracing id and context id from headers let request_id = uuid::Uuid::new_v4().to_string(); @@ -180,17 +172,6 @@ async fn completions( // todo - when optional, if none, apply a default let model = &request.inner.model; - // Rate limit check - let should_reject = state - .rate_limiter() - .should_reject(model) - .map_err(|_| ErrorResponse::internal_server_error("Failed to check rate limit"))?; - if should_reject { - return Err(ErrorResponse::rate_limit_exceeded(&format!( - "Rate limit exceeded for current request and model: {model}. Please retry later." - ))); - } - // todo - error handling should be more robust let engine = state .manager() @@ -202,7 +183,9 @@ async fn completions( .metrics_clone() .create_inflight_guard(model, Endpoint::Completions, streaming); - let mut response_collector = state.metrics_clone().create_response_collector(model); + let mut response_collector = state + .metrics_clone() + .create_response_collector(model, state.rate_limiter_clone()); // setup context // todo - inherit request_id from distributed trace details @@ -324,6 +307,9 @@ async fn chat_completions( // return a 503 if the service is not ready check_ready(&state)?; + // Rate limit check + should_reject_request(&state, &request.inner.model)?; + // Handle unsupported fields - if Some(resp) is returned by // validate_chat_completion_unsupported_fields, // then a field was used that is unsupported. We will log an error message @@ -367,17 +353,6 @@ async fn chat_completions( // todo - when optional, if none, apply a default let model = &request.inner.model; - // Rate limit check - let should_reject = state - .rate_limiter() - .should_reject(model) - .map_err(|_| ErrorResponse::internal_server_error("Failed to check rate limit"))?; - if should_reject { - return Err(ErrorResponse::rate_limit_exceeded(&format!( - "Rate limit exceeded for current request and model: {model}. Please retry later." 
- ))); - } - // todo - determine the proper error code for when a request model is not present tracing::trace!("Getting chat completions engine for model: {}", model); @@ -391,7 +366,9 @@ async fn chat_completions( .metrics_clone() .create_inflight_guard(model, Endpoint::ChatCompletions, streaming); - let mut response_collector = state.metrics_clone().create_response_collector(model); + let mut response_collector = state + .metrics_clone() + .create_response_collector(model, state.rate_limiter_clone()); // setup context // todo - inherit request_id from distributed trace details @@ -549,7 +526,9 @@ async fn responses( .metrics_clone() .create_inflight_guard(model, Endpoint::Responses, false); - let _response_collector = state.metrics_clone().create_response_collector(model); + let _response_collector = state + .metrics_clone() + .create_response_collector(model, state.rate_limiter_clone()); let request = Context::with_id(request, request_id.clone()); @@ -591,6 +570,25 @@ async fn responses( Ok(Json(response).into_response()) } +pub fn should_reject_request( + state: &Arc, + model: &str, +) -> Result<(), (StatusCode, Json)> { + if !state.rate_limiter().is_enabled() { + return Ok(()); + } + + let should_reject = state.rate_limiter().should_reject(model); + + if should_reject { + return Err(ErrorResponse::rate_limit_exceeded(&format!( + "Rate limit exceeded for current request and model: {model}. Please retry later." + ))); + } + + Ok(()) +} + pub fn validate_response_input_is_text_only( request: &NvCreateResponse, ) -> Option { diff --git a/lib/llm/src/http/service/rate_limiter.rs b/lib/llm/src/http/service/rate_limiter.rs index 0f038abf71..6ceb4cce42 100644 --- a/lib/llm/src/http/service/rate_limiter.rs +++ b/lib/llm/src/http/service/rate_limiter.rs @@ -1,150 +1,263 @@ -use std::{ - collections::HashMap, - sync::{Arc, RwLock}, - time::{Duration, Instant}, -}; +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 -use anyhow::{Context, Result}; -use hdrhistogram::Histogram; +//! Rate limiter implementation for the OpenAI API compatible HTTP service. +//! +//! The rate limiter is used to limit the rate of requests to the HTTP service. +//! It is used to prevent abuse of the service, under heavy load, +//! and to ensure that the service is available to all users. The system values +//! 'good-put' (that is, the throughput of the system under good performance metrics, in the form of +//! time to first token and inter-token latency) over 'throughput' (that is, the total amount of +//! tokens processed), and the rate limiter is used to ensure that the service is available to all +//! users, even under heavy load. +//! +//! The rate limiter is implemented using a time-weighted exponential moving average (EMA). +//! The time-weighted average is computed using the following formula: +//! +//! ```text +//! average = sum(value * weight) / sum(weight) +//! ``` +//! +//! Where `weight` is the weight of the sample based on the age of the sample and the time constant: +//! +//! ```text +//! age = now - record_time +//! weight = exp(-age / time_constant_secs) +//! ``` +//! +//! Where `now` is the current time, `record_time` is the time the sample was recorded, +//! and `time_constant_secs` is the time constant for the time-weighted average. +//! +//! Moreover, we decay the average to account for the time elapsed since the last update. +//! This models "system recovery" during idle time. 
This is done by multiplying the average by the +//! decay factor: +//! +//! ```text +//! decayed_average = average * exp(-time_elapsed / time_constant_secs) +//! ``` -const MODEL_METRICS_PRECISION: u8 = 6; +use std::time::Instant; -#[derive(Debug, Clone)] +use dashmap::DashMap; +use validator::Validate; + +/// Configuration for the rate limiter +#[derive(Debug, Clone, Validate)] pub struct RateLimiterConfig { - enabled: bool, + /// Threshold for the time to first token metric + #[validate(range(min = 0.0))] ttft_threshold_ms: f64, - ttft_percentile: f64, + /// Threshold for the inter-token latency metric + #[validate(range(min = 0.0))] itl_threshold_ms: f64, - itl_percentile: f64, - window_duration: Duration, + /// Time constant for the time-weighted EMA + #[validate(range(min = 0.001))] + time_constant_secs: f64, + /// Whether to use per-model limits per_model_limits: bool, } +impl RateLimiterConfig { + pub fn new( + ttft_threshold_ms: f64, + itl_threshold_ms: f64, + time_constant_secs: f64, + per_model_limits: bool, + ) -> Self { + Self { + ttft_threshold_ms, + itl_threshold_ms, + time_constant_secs, + per_model_limits, + } + } + + pub fn empty() -> Self { + Self { + ttft_threshold_ms: 0.0, + itl_threshold_ms: 0.0, + time_constant_secs: 0.001, + per_model_limits: false, + } + } +} + impl Default for RateLimiterConfig { fn default() -> Self { Self { - enabled: true, ttft_threshold_ms: 1000.0, // 1s itl_threshold_ms: 10.0, // 10ms - ttft_percentile: 0.95, // 95th percentile - itl_percentile: 0.95, // 95th percentile - window_duration: Duration::from_secs(5), + time_constant_secs: 30.0, // 30s per_model_limits: false, } } } +/// Tracks recent samples to compute time-weighted averages of a metric. Formally, +/// the time-weighted average is defined as: +/// +/// ```text +/// average = sum(value * weight) / sum(weight) +/// ``` +/// +/// Where `weight` is the weight of the sample based on the age of the sample and the time constant: +/// +/// ```text +/// age = now - record_time +/// weight = exp(-age / time_constant_secs) +/// ``` +/// +/// Where `now` is the current time, `record_time` is the time the sample was recorded, +/// and `time_constant_secs` is the time constant for the time-weighted average. +/// In this way, more recent samples have a higher weight than older samples, the latter +/// decaying exponentially towards zero (making it less impactful for the current average calculation). 
+/// +/// In order to compute the time-weighted average more efficiently, we leverage the well +/// known property of the exponential function: +/// +/// ```text +/// exp(x) = exp(y) * exp(x - y) +/// ``` +/// +/// This allows us to compute the time-weighted average, recursively, in a single pass, +/// (see Markov's property) as follows: +/// +/// ```text +/// previous_weight_total = sum(weight) +/// updated_factor = 1 / (1 + previous_weight_total * exp(-age / time_constant_secs)) +/// average(now) = average(last_time) * updated_factor + value * (1 - updated_factor) +/// ``` #[derive(Debug)] -pub struct WindowedHistogram { - current: Histogram, - previous: Option>, - window_start: Instant, - window_duration: Duration, - /// Max value in microseconds - max_value: u64, - precision: u8, +pub struct TimeWeightedAverageTracker { + /// Last computed time-weighted average + last_weighted_average: f64, + /// Last total weight sum + last_total_weight: f64, + /// Last observed time + last_time: Instant, + /// Time constant for the time-weighted average + time_constant_secs: f64, } -impl WindowedHistogram { - pub fn new(window_duration: Duration, max_value: u64, precision: u8) -> Result { - let histogram = Histogram::::new_with_bounds(1, max_value, precision) - .context("Failed to create histogram")?; - - Ok(Self { - current: histogram, - previous: None, - window_start: Instant::now(), - window_duration, - max_value, - precision, - }) - } - - fn record_value(&mut self, value: f64) -> Result<()> { - self.maybe_rotate_window()?; - - // Convert to microseconds for better precision - let value_us = (value * 1000.0).round() as u64; - let clamped_value = value_us.min(self.max_value).max(1); - self.current - .record(clamped_value) - .context("Failed to record value to histogram")?; - - Ok(()) +impl TimeWeightedAverageTracker { + pub fn new(time_constant_secs: f64) -> Self { + let now = Instant::now(); + Self { + last_weighted_average: 0., + last_total_weight: 0., + last_time: now, + time_constant_secs, + } } - fn get_percentile(&self, percentile: f64) -> Result { - let sample_count = self.sample_count(); + /// Record a new value to the tracker. + pub fn record_value(&mut self, value: f64) { + let now = Instant::now(); + if self.last_weighted_average == 0. && self.last_total_weight == 0. { + // First sample + self.last_weighted_average = value; + self.last_total_weight = 1.; + } else { + let time_elapsed = now.duration_since(self.last_time).as_secs_f64(); + let decay_factor = (-time_elapsed / self.time_constant_secs).exp(); - if sample_count == 0 { - return Ok(0.0); + // Update the weighted average, using recursive EMA formula + self.last_total_weight = 1. + self.last_total_weight * decay_factor; + let alpha = 1. / self.last_total_weight; + self.last_weighted_average = alpha * value + (1. - alpha) * self.last_weighted_average; } - let percentile_us = self.current.value_at_percentile(percentile) as f64; - - // Convert back from microseconds to milliseconds - Ok(percentile_us / 1000.0) + self.last_time = now; } - fn sample_count(&self) -> u64 { - self.current.len() + self.previous.as_ref().map_or(0, |h| h.len()) + /// Get the current time-weighted average, decayed to account for the time elapsed since the last update. 
+ pub fn get_decayed_time_weighted_average(&self) -> f64 { + let now = Instant::now(); + let time_elapsed = now.duration_since(self.last_time).as_secs_f64(); + let decay_factor = (-time_elapsed / self.time_constant_secs).exp(); + self.last_weighted_average * decay_factor } +} - fn maybe_rotate_window(&mut self) -> Result<()> { - let now = Instant::now(); - if now.duration_since(self.window_start) > self.window_duration { - let new_histogram = - Histogram::::new_with_bounds(1, self.max_value, self.precision) - .context("Failed to create new histogram")?; +#[derive(Debug)] +struct ModelMetrics { + ttft_tracker: TimeWeightedAverageTracker, + itl_tracker: TimeWeightedAverageTracker, +} + +impl ModelMetrics { + fn new(config: &RateLimiterConfig) -> Self { + let ttft_tracker = TimeWeightedAverageTracker::new(config.time_constant_secs); + let itl_tracker = TimeWeightedAverageTracker::new(config.time_constant_secs); - self.previous = Some(std::mem::replace(&mut self.current, new_histogram)); - self.window_start = now; + Self { + ttft_tracker, + itl_tracker, } + } +} + +#[derive(Debug, Clone)] +pub struct RateLimiterMetrics { + pub ttft_diagnostics: TimeWeightedDiagnostics, + pub itl_diagnostics: TimeWeightedDiagnostics, +} - Ok(()) +impl std::fmt::Display for RateLimiterMetrics { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "RateLimiterMetrics {{\n TTFT: {},\n ITL: {}\n}}", + self.ttft_diagnostics, self.itl_diagnostics + ) } } -#[derive(Debug)] -struct ModelMetrics { - ttft_histogram: WindowedHistogram, - itl_histogram: WindowedHistogram, +#[derive(Debug, Clone)] +pub struct TimeWeightedDiagnostics { + pub decayed_time_weighted_average: f64, + pub time_constant_secs: f64, + pub last_weighted_sum: f64, + pub last_time: Instant, } -impl ModelMetrics { - fn new(config: &RateLimiterConfig) -> Result { - let ttft_histogram = WindowedHistogram::new( - config.window_duration, - (config.ttft_threshold_ms * 1000.0).round() as u64, // Convert to microseconds - MODEL_METRICS_PRECISION, - )?; - let itl_histogram = WindowedHistogram::new( - config.window_duration, - (config.itl_threshold_ms * 1000.0).round() as u64, // Convert to microseconds - MODEL_METRICS_PRECISION, - )?; - - Ok(Self { - ttft_histogram, - itl_histogram, - }) +impl std::fmt::Display for TimeWeightedDiagnostics { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "TimeWeightedDiagnostics {{ \ + decayed_time_weighted_average: {:.3}, \ + time_constant_secs: {:.1}, \ + last_weighted_sum: {:.3}, \ + duration_since_last_update: {:?} \ + }}", + self.decayed_time_weighted_average, + self.time_constant_secs, + self.last_weighted_sum, + self.last_time.elapsed().as_secs_f64() + ) } } pub struct RateLimiter { config: RateLimiterConfig, - // TODO: Can make this a `DashMap` to avoid the need to lock the entire map - model_metrics: Arc>>, + model_metrics: DashMap, + is_enabled: bool, } impl RateLimiter { - pub fn new(config: RateLimiterConfig) -> Self { + pub fn new(config: Option) -> Self { Self { - config, - model_metrics: Arc::new(RwLock::new(HashMap::new())), + is_enabled: config.is_some(), + config: config.unwrap_or_else(|| RateLimiterConfig::empty()), + model_metrics: DashMap::new(), } } + pub fn is_enabled(&self) -> bool { + self.is_enabled + } + #[inline] fn get_model_key(&self, model: &str) -> String { if self.config.per_model_limits { @@ -155,94 +268,861 @@ impl RateLimiter { } /// Record the time to first token metric for a given model - pub fn record_ttft(&self, 
model: &str, ttft_ms: f64) -> Result<()> { - if !self.config.enabled { - return Ok(()); - } - + pub fn record_ttft(&self, model: &str, ttft_ms: f64) { let model_key = self.get_model_key(model); - let mut metrics = self.model_metrics.write().unwrap(); - let model_metrics = metrics + let mut model_metrics = self + .model_metrics .entry(model_key) - .or_insert_with(|| ModelMetrics::new(&self.config).unwrap()); + .or_insert_with(|| ModelMetrics::new(&self.config)); - model_metrics - .ttft_histogram - .record_value(ttft_ms) - .context("Failed to record time to first token metric") + model_metrics.ttft_tracker.record_value(ttft_ms); } /// Record the inter-token latency metric for a given model - pub fn record_itl(&self, model: &str, itl_ms: f64) -> Result<()> { - if !self.config.enabled { - return Ok(()); - } - + pub fn record_itl(&self, model: &str, itl_ms: f64) { let model_key = self.get_model_key(model); - let mut metrics = self.model_metrics.write().unwrap(); - let model_metrics = metrics + let mut model_metrics = self + .model_metrics .entry(model_key) - .or_insert_with(|| ModelMetrics::new(&self.config).unwrap()); + .or_insert_with(|| ModelMetrics::new(&self.config)); - model_metrics - .itl_histogram - .record_value(itl_ms) - .context("Failed to record inter-token latency metric") + model_metrics.itl_tracker.record_value(itl_ms); } /// Check if the request should be rejected based on the cached metrics /// /// Returns true if the request should be rejected, false otherwise - pub fn should_reject(&self, model: &str) -> Result { - if !self.config.enabled { - return Ok(false); - } - + pub fn should_reject(&self, model: &str) -> bool { let model_key = self.get_model_key(model); - let metrics = self.model_metrics.write().unwrap(); + let model_metrics = self.model_metrics.get(&model_key); - let Some(model_metrics) = metrics.get(&model_key) else { - return Ok(false); + let Some(model_metrics) = model_metrics else { + return false; }; - let ttft_percentile_ms = model_metrics - .ttft_histogram - .get_percentile(self.config.ttft_percentile)?; - let itl_percentile_ms = model_metrics - .itl_histogram - .get_percentile(self.config.itl_percentile)?; + // Get decayed time-weighted EMA values + let decayed_ttft_ema = model_metrics + .ttft_tracker + .get_decayed_time_weighted_average(); + let decayed_itl_ema = model_metrics + .itl_tracker + .get_decayed_time_weighted_average(); + + drop(model_metrics); - let ttft_samples = model_metrics.ttft_histogram.sample_count(); - let itl_samples = model_metrics.itl_histogram.sample_count(); + let ttft_exceeded = self.config.ttft_threshold_ms < decayed_ttft_ema; + let itl_exceeded = self.config.itl_threshold_ms < decayed_itl_ema; - // Don't reject if we don't have enough samples - if ttft_samples == 0 || itl_samples == 0 { - return Ok(false); + if ttft_exceeded || itl_exceeded { + let rate_limiter_metrics = self.get_metrics(&model_key); + self.log_metrics(model, rate_limiter_metrics); + return true; } - let ttft_exceeded = self.config.ttft_threshold_ms <= ttft_percentile_ms; - let itl_exceeded = self.config.itl_threshold_ms <= itl_percentile_ms; + false + } + + /// Get current metrics and diagnostics for current model + #[inline] + fn get_metrics(&self, model_key: &str) -> RateLimiterMetrics { + let model_metrics = self.model_metrics.get(model_key).unwrap(); + let decayed_ttft_ema = model_metrics + .ttft_tracker + .get_decayed_time_weighted_average(); + let decayed_itl_ema = model_metrics + .itl_tracker + .get_decayed_time_weighted_average(); + let 
ttft_last_weighted_sum = model_metrics.ttft_tracker.last_total_weight; + let itl_last_weighted_sum = model_metrics.itl_tracker.last_total_weight; + let ttft_last_time = model_metrics.ttft_tracker.last_time; + let itl_last_time = model_metrics.itl_tracker.last_time; - if ttft_exceeded || itl_exceeded { - tracing::warn!( - model = model, - ttft_threshold_ms = self.config.ttft_threshold_ms, - itl_threshold_ms = self.config.itl_threshold_ms, - "Rate limit exceeded for model {model}: ttft: {ttft_percentile_ms}ms, itl: {itl_percentile_ms}ms", - ttft_percentile_ms = ttft_percentile_ms, - itl_percentile_ms = itl_percentile_ms, + RateLimiterMetrics { + ttft_diagnostics: TimeWeightedDiagnostics { + decayed_time_weighted_average: decayed_ttft_ema, + time_constant_secs: self.config.time_constant_secs, + last_weighted_sum: ttft_last_weighted_sum, + last_time: ttft_last_time, + }, + itl_diagnostics: TimeWeightedDiagnostics { + decayed_time_weighted_average: decayed_itl_ema, + time_constant_secs: self.config.time_constant_secs, + last_weighted_sum: itl_last_weighted_sum, + last_time: itl_last_time, + }, + } + } + + fn log_metrics(&self, model: &str, metrics: RateLimiterMetrics) { + tracing::warn!( + model = model, + ttft_threshold_ms = self.config.ttft_threshold_ms, + itl_threshold_ms = self.config.itl_threshold_ms, + "Rate limit exceeded for model {model}: {metrics}", + metrics = metrics, + ); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::Ordering; + use std::sync::{atomic::AtomicUsize, Arc}; + use std::thread; + use std::time::Duration; + + #[test] + fn test_simple_time_weighted_average_tracker() { + const TIME_CONSTANT_SECS: f64 = 1.0; // Short time constant + + const SLEEP_DURATION_MS: u64 = 100; + + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + // Add samples with increasing delays + tracker.record_value(100.0); + thread::sleep(Duration::from_millis(SLEEP_DURATION_MS)); + tracker.record_value(200.0); + thread::sleep(Duration::from_millis(SLEEP_DURATION_MS)); + tracker.record_value(300.0); + thread::sleep(Duration::from_millis(SLEEP_DURATION_MS)); + tracker.record_value(400.0); + thread::sleep(Duration::from_millis(20 * SLEEP_DURATION_MS)); // Long gap + tracker.record_value(500.0); + + let avg = tracker.get_decayed_time_weighted_average(); + assert!(avg > 0.0, "Average should be positive"); + } + + #[test] + fn test_edge_case_all_samples_below_threshold() { + const TIME_CONSTANT_SECS: f64 = 0.1; + const SLEEP_DURATION_MS: u64 = 2_000; + const EPSILON: f64 = 1e-5; + + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + tracker.record_value(100.0); + thread::sleep(Duration::from_millis(SLEEP_DURATION_MS)); + + let avg = tracker.get_decayed_time_weighted_average(); + + // Should return a close to 0.0 when time constant is small and time passed is large + assert!(avg < EPSILON, "Average should be 0.0: {}", avg); + } + + #[test] + fn test_edge_case_single_sample() { + const TIME_CONSTANT_SECS: f64 = 10.; + const EPSILON: f64 = 0.5; // exp(-0.01) ~= 0.99 + + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + tracker.record_value(42.); + thread::sleep(Duration::from_millis(100)); + + let avg = tracker.get_decayed_time_weighted_average(); + + assert!( + (avg - 42.).abs() < EPSILON, + "Average should be close to 42.0: {}", + avg + ); + } + + #[test] + fn test_time_weighted_average_tracker_correctness() { + const TIME_CONSTANT_SECS: f64 = 10.0; + const NUM_SAMPLES: usize = 100; + const SLEEP_DURATION_MS: 
u64 = 1; + + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + // Record old sample + tracker.record_value(1000.0); + thread::sleep(Duration::from_millis(1_000 * SLEEP_DURATION_MS)); + + // Add more recent samples with lower values + for _ in 0..NUM_SAMPLES { + tracker.record_value(100.0); + thread::sleep(Duration::from_millis(SLEEP_DURATION_MS)); + } + + let avg = tracker.get_decayed_time_weighted_average(); + assert!( + avg < 500.0, + "Average should be dominated by recent samples: {}", + avg + ); + assert!( + avg > 100.0, + "Average should still be influenced by old sample: {}", + avg + ); + } + + #[test] + fn test_time_weighted_average_quantitative_analysis() { + const TIME_CONSTANT_SECS: f64 = 2.0; // 2 second time constant + const EPSILON: f64 = 0.05; // 5% tolerance for timing precision + const SLEEP_DURATION_MS: u64 = 100; // 100ms + + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + // Record samples with known values and controlled timing + let sample_values = vec![100.0, 200.0, 300.0, 400.0]; + let sample_delays_ms = vec![0, 500, 1000, 1500]; // Delays in milliseconds + + let start_time = Instant::now(); + + // Record first sample immediately + tracker.record_value(sample_values[0]); + + // Record subsequent samples with known delays + for i in 1..sample_values.len() { + thread::sleep(Duration::from_millis( + sample_delays_ms[i] - sample_delays_ms[i - 1], + )); + tracker.record_value(sample_values[i]); + } + + // Wait a bit more, then calculate + thread::sleep(Duration::from_millis(SLEEP_DURATION_MS)); + let calculation_time = Instant::now(); + + // Calculate expected weighted average manually + let total_elapsed = calculation_time.duration_since(start_time); + let mut expected_weighted_sum = 0.0; + let mut expected_total_weight = 0.0; + + for i in 0..sample_values.len() { + // Age of this sample = total_elapsed - delay_when_recorded + let sample_age_secs = + total_elapsed.as_secs_f64() - (sample_delays_ms[i] as f64 / 1000.0); + let weight = f64::exp(-sample_age_secs / TIME_CONSTANT_SECS); + + expected_weighted_sum += sample_values[i] * weight; + expected_total_weight += weight; + + println!( + "Sample {}: value={}, age={:.3}s, weight={:.6}", + i, sample_values[i], sample_age_secs, weight ); - return Ok(true); } - Ok(false) + let expected_average = + (expected_weighted_sum / expected_total_weight) * f64::exp(-0.1 / TIME_CONSTANT_SECS); // 0.1s is the time elapsed since the last sample + let actual_average = tracker.get_decayed_time_weighted_average(); + + println!("Expected average: {:.6}", expected_average); + println!("Actual average: {:.6}", actual_average); + println!( + "Difference: {:.6}", + (actual_average - expected_average).abs() + ); + println!( + "Relative error: {:.4}%", + 100.0 * (actual_average - expected_average).abs() / expected_average + ); + + // Verify the calculation is mathematically correct within tolerance + let relative_error = (actual_average - expected_average).abs() / expected_average; + assert!( + relative_error < EPSILON, + "Time-weighted average calculation error too large: expected {:.6}, got {:.6}, relative error {:.4}%", + expected_average, actual_average, relative_error * 100.0 + ); + + // Additional verification: more recent samples should have higher influence + // Sample 3 (400.0) is most recent, so if we compare with a simple average: + let simple_average = sample_values.iter().sum::() / sample_values.len() as f64; + println!("Simple average: {:.6}", simple_average); + + // The time-weighted 
average should be closer to the most recent value (400.0) + // than the simple average, since recent samples have higher weights + let distance_to_recent = (actual_average - 400.0).abs(); + let distance_simple_to_recent = (simple_average - 400.0).abs(); + + assert!( + distance_to_recent < distance_simple_to_recent, + "Time-weighted average should be closer to recent values than simple average: \ + weighted_avg={:.2}, simple_avg={:.2}, recent_value=400.0", + actual_average, + simple_average + ); } - pub fn clear_model_metrics(&self, model: &str) -> Result<()> { - let model_key = self.get_model_key(model); - let mut metrics = self.model_metrics.write().unwrap(); - metrics.remove(&model_key); + #[test] + fn test_exponential_decay_verification() { + const TIME_CONSTANT_SECS: f64 = 1.0; // 1 second time constant + const EPSILON: f64 = 0.02; // 2% tolerance + + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + // Record a high value, then wait exactly one time constant + tracker.record_value(1000.0); + thread::sleep(Duration::from_millis(1_000)); // Wait 1 second = 1 time constant + + // Record a low value + tracker.record_value(100.0); + + let actual_average = tracker.get_decayed_time_weighted_average(); + + // After 1 time constant, the old sample should have weight = e^(-1) ≈ 0.368 + // New sample has weight ≈ 1.0 + let old_weight = f64::exp(-1.0); // ≈ 0.368 + let new_weight = 1.0; + + let expected_average = + (1000.0 * old_weight + 100.0 * new_weight) / (old_weight + new_weight); + + println!("Old weight (e^-1): {:.6}", old_weight); + println!("New weight: {:.6}", new_weight); + println!("Expected average: {:.6}", expected_average); + println!("Actual average: {:.6}", actual_average); + + let relative_error = (actual_average - expected_average).abs() / expected_average; + assert!( + relative_error < EPSILON, + "Exponential decay verification failed: expected {:.6}, got {:.6}, error {:.4}%", + expected_average, + actual_average, + relative_error * 100.0 + ); + + // Verify the theoretical calculation: should be around 463.4 + assert!( + (expected_average - 342.04727).abs() < 1e-5, + "Theoretical calculation seems wrong: {:.1}", + expected_average + ); + } + + #[test] + fn test_mathematical_properties() { + const TIME_CONSTANT_SECS: f64 = 2.0; + + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + // Property 1: Single sample should return its own value + tracker.record_value(42.0); + let single_avg = tracker.get_decayed_time_weighted_average(); + assert!( + (single_avg - 42.0).abs() < 1e-6, + "Single sample average should equal sample value: {}", + single_avg + ); + + // Property 2: Adding identical samples should not change average + tracker.record_value(42.0); + tracker.record_value(42.0); + let identical_avg = tracker.get_decayed_time_weighted_average(); + assert!( + (identical_avg - 42.0).abs() < 1e-5, + "Identical samples should maintain average: {}", + identical_avg + ); + + // Property 3: Average should be bounded by min and max values + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + let values = vec![10.0, 50.0, 30.0, 70.0, 20.0]; + let min_val = 10.0; + let max_val = 70.0; + + for &val in &values { + tracker.record_value(val); + thread::sleep(Duration::from_millis(10)); + } + + let bounded_avg = tracker.get_decayed_time_weighted_average(); + assert!( + bounded_avg >= min_val && bounded_avg <= max_val, + "Average should be bounded: {:.2} not in [{:.2}, {:.2}]", + bounded_avg, + min_val, + max_val + ); + + 
println!("Values: {:?}", values); + println!( + "Average: {:.2} ∈ [{:.2}, {:.2}] ✓", + bounded_avg, min_val, max_val + ); + } + + #[test] + fn test_concurrent_access_simulation() { + const NUM_THREADS: usize = 10; + const NUM_RECORDS: usize = 100; + + const SLEEP_INTERVAL: usize = 10; + const SLEEP_DURATION_MS: u64 = 1; + + let config = RateLimiterConfig { + ttft_threshold_ms: 1000.0, + itl_threshold_ms: 100.0, + time_constant_secs: 30.0, + per_model_limits: false, + }; + let limiter = Arc::new(RateLimiter::new(Some(config))); + let error_count = Arc::new(AtomicUsize::new(0)); + + let mut handles = Vec::new(); + + for i in 0..NUM_THREADS { + let limiter_clone = limiter.clone(); + let error_count_clone = error_count.clone(); + + handles.push(thread::spawn(move || { + for j in 0..NUM_RECORDS { + limiter_clone.record_ttft("model", (i * NUM_RECORDS + j) as f64); + limiter_clone.record_itl("model", (i + j) as f64); + + if limiter_clone.should_reject("model") { + error_count_clone.fetch_add(1, Ordering::Relaxed); + } + + if j % SLEEP_INTERVAL == 0 { + thread::sleep(Duration::from_millis(SLEEP_DURATION_MS)); + } + } + })); + } + + for handle in handles { + handle.join().unwrap(); + } + + // Should have no errors + let error_count = error_count.load(Ordering::Relaxed); + assert_eq!(error_count, 0, "Error count should be 0: {}", error_count); + } + + #[test] + fn test_concurrent_access_simulation_with_error_count() { + const NUM_THREADS: usize = 10; + const NUM_RECORDS: usize = 100; + + const SLEEP_INTERVAL: usize = 10; + const SLEEP_DURATION_MS: u64 = 1; + + let config = RateLimiterConfig { + ttft_threshold_ms: 1000.0, + itl_threshold_ms: 10.0, + time_constant_secs: 30.0, + per_model_limits: false, + }; + let limiter = Arc::new(RateLimiter::new(Some(config))); + let error_count = Arc::new(AtomicUsize::new(0)); + + let mut handles = Vec::new(); + + for i in 0..NUM_THREADS { + let limiter_clone = limiter.clone(); + let error_count_clone = error_count.clone(); + + handles.push(thread::spawn(move || { + for j in 0..NUM_RECORDS { + limiter_clone.record_ttft("model", (i * NUM_RECORDS + j) as f64); + limiter_clone.record_itl("model", (i + j) as f64); + + if limiter_clone.should_reject("model") { + error_count_clone.fetch_add(1, Ordering::Relaxed); + } + + if j % SLEEP_INTERVAL == 0 { + thread::sleep(Duration::from_millis(SLEEP_DURATION_MS)); + } + } + })); + } + + for handle in handles { + handle.join().unwrap(); + } + + // Roughly 10% of the time, we should have an error, we set the threshold to 12% to account for + // the effect of the time passing. 
+ let error_count = error_count.load(Ordering::Relaxed); + assert!( + error_count > 880 && error_count < 920, + "Error count should be around 12% of the time: {}", + error_count + ); + } + + #[test] + fn test_concurrent_operations() { + use std::sync::Mutex; + + const TIME_CONSTANT_SECS: f64 = 10.0; + + const SLEEP_DURATION_MS: u64 = 1; + + const NUM_THREADS: usize = 5; + const NUM_RECORDS: usize = 20; + + let tracker = Arc::new(Mutex::new(TimeWeightedAverageTracker::new( + TIME_CONSTANT_SECS, + ))); + + let mut handles = Vec::new(); + + // Spawn multiple threads adding values + for thread_id in 0..NUM_THREADS { + let tracker_clone = Arc::clone(&tracker); + let handle = thread::spawn(move || { + for i in 0..NUM_RECORDS { + let value = (thread_id * 100 + i) as f64; + tracker_clone.lock().unwrap().record_value(value); + thread::sleep(Duration::from_millis(SLEEP_DURATION_MS)); + } + }); + handles.push(handle); + } + + // Also spawn a thread that computes averages + const NUM_AVERAGES: usize = 10; + const SLEEP_DURATION_MS_AVERAGE: u64 = 5; + + let tracker_clone = Arc::clone(&tracker); + let avg_handle = thread::spawn(move || { + for _ in 0..NUM_AVERAGES { + let avg = tracker_clone + .lock() + .unwrap() + .get_decayed_time_weighted_average(); + assert!(avg > 0.0, "Average should be positive"); + thread::sleep(Duration::from_millis(SLEEP_DURATION_MS_AVERAGE)); + } + }); + + // Wait for all threads + for handle in handles { + handle.join().unwrap(); + } + avg_handle.join().unwrap(); + + let final_avg = tracker.lock().unwrap().get_decayed_time_weighted_average(); + assert!(final_avg > 0.0, "Final average should be positive"); + } + + #[test] + fn test_rate_limiter_integration() { + let config = RateLimiterConfig { + ttft_threshold_ms: 100.0, // 100ms + itl_threshold_ms: 5.0, // 5ms + time_constant_secs: 1.0, + ..Default::default() + }; + + let limiter = Arc::new(RateLimiter::new(Some(config))); + + // Record low values - should not trigger + limiter.record_ttft("test", 50.0); + limiter.record_ttft("test", 60.0); + limiter.record_ttft("test", 70.0); + + thread::sleep(Duration::from_millis(150)); // Wait for warmup + + assert!( + !limiter.should_reject("test"), + "Should not reject with low values" + ); + + // Record high values - should trigger + limiter.record_ttft("test", 200.0); + limiter.record_ttft("test", 300.0); + + assert!( + limiter.should_reject("test"), + "Should reject with high values" + ); + } + + #[test] + fn test_rate_limiter_integration_samples_close_to_trigger() { + const NUM_SAMPLES: usize = 100; + + let config = RateLimiterConfig { + ttft_threshold_ms: 100.0, // 100ms + itl_threshold_ms: 5.0, // 5ms + time_constant_secs: 1.0, + ..Default::default() + }; + + let limiter = Arc::new(RateLimiter::new(Some(config))); + + // Record low values - should not trigger + limiter.record_ttft("test", 50.0); + limiter.record_ttft("test", 60.0); + limiter.record_ttft("test", 70.0); + + thread::sleep(Duration::from_millis(150)); // Wait for warmup + + assert!( + !limiter.should_reject("test"), + "Should not reject with low values" + ); + + // Record multiple values close to trigger + for i in 0..NUM_SAMPLES { + limiter.record_ttft("test", 100.0 + i as f64 / 10.0); + } + + assert!( + limiter.should_reject("test"), + "Should reject with high values" + ); + } + + #[test] + fn test_per_model_vs_global_limits() { + const MODEL_A: &str = "model_a"; + const MODEL_B: &str = "model_b"; + + let global_config = RateLimiterConfig { + per_model_limits: false, + ..Default::default() + }; + + let 
per_model_config = RateLimiterConfig { + per_model_limits: true, + ..Default::default() + }; + + let global_limiter = RateLimiter::new(Some(global_config)); + let per_model_limiter = RateLimiter::new(Some(per_model_config)); + + // Record high values for model A + global_limiter.record_ttft(MODEL_A, 2000.0); + global_limiter.record_ttft(MODEL_A, 2000.0); + + per_model_limiter.record_ttft(MODEL_A, 2000.0); + per_model_limiter.record_ttft(MODEL_A, 2000.0); + + thread::sleep(Duration::from_millis(20)); + + // Both should reject model A + assert!(global_limiter.should_reject(MODEL_A)); + assert!(per_model_limiter.should_reject(MODEL_A)); + + // Global limiter should also reject model B (uses same "global" key) + assert!(global_limiter.should_reject(MODEL_B)); + + // Per-model limiter should NOT reject model B (separate tracking) + assert!(!per_model_limiter.should_reject(MODEL_B)); + } + + #[test] + fn test_numerical_stability_long_time_series() { + // Scenario 1: Very small time constant with long time series + { + const TIME_CONSTANT_SECS: f64 = 0.01; // Very small time constant + const NUM_SAMPLES: usize = 10_000; + const SLEEP_DURATION_MICROS: u64 = 100; // 0.1ms between samples + + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + // Add many samples with controlled timing + for i in 0..NUM_SAMPLES { + tracker.record_value((i % 100) as f64); // Cycling values 0-99 + + if i % 1000 == 0 { + // Add occasional small delays to test very old samples + thread::sleep(Duration::from_micros(SLEEP_DURATION_MICROS)); + } + } + + let avg = tracker.get_decayed_time_weighted_average(); + + // Should be finite and reasonable + assert!( + avg.is_finite(), + "Average should be finite with small time constant" + ); + assert!( + avg > 0.0 && avg < 100.0, + "Average should be bounded by sample range: {}", + avg + ); + + // With small time constant, should be dominated by recent samples (90-99 range) + assert!( + (avg - 50.0).abs() < 0.5, + "With small time constant, average should reflect recent samples: {}", + avg + ); + } + + // Scenario 2: Extreme value ranges + { + const TIME_CONSTANT_SECS: f64 = 1.0; + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + let extreme_values = vec![ + 1e-10, // Very small positive + 1e10, // Very large + 1e-15, // Tiny + 1e15, // Huge + 0.001, // Small + 1000000.0, // Large + ]; + + for &value in &extreme_values { + tracker.record_value(value); + thread::sleep(Duration::from_millis(1)); + } + + let avg = tracker.get_decayed_time_weighted_average(); + assert!( + avg.is_finite(), + "Average should handle extreme values gracefully" + ); + assert!(avg > 0.0, "Average of positive values should be positive"); + } + + // Scenario 3: Accumulated precision test with repetitive operations + { + const TIME_CONSTANT_SECS: f64 = 10.0; + const NUM_ITERATIONS: usize = 50_000; + + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + // Record the same value many times to test for accumulated rounding errors + let test_value = 42.42424242424242; // Value with many decimal places + + for _ in 0..NUM_ITERATIONS { + tracker.record_value(test_value); + } + + let avg = tracker.get_decayed_time_weighted_average(); + + // Should be very close to the test value (within reasonable floating point precision) + let relative_error = (avg - test_value).abs() / test_value; + assert!( + relative_error < 1e-6, + "Accumulated rounding error too large: {} vs {}, error: {:.2e}", + avg, + test_value, + relative_error + ); + } + + // Scenario 
4: Weight underflow protection + { + const TIME_CONSTANT_SECS: f64 = 0.1; // Small time constant + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + // Add initial sample + tracker.record_value(1000.0); + + // Wait a very long time (relative to time constant) + thread::sleep(Duration::from_millis(2000)); // 20 time constants + + // Add recent samples - old sample should have negligible weight + for i in 0..100 { + tracker.record_value(100.0 + i as f64); + thread::sleep(Duration::from_micros(100)); + } + + let avg = tracker.get_decayed_time_weighted_average(); + + // Should be dominated by recent samples, not the old high value + assert!( + avg < 500.0, + "Very old samples should have negligible impact: {}", + avg + ); + assert!(avg > 100.0, "Average should still be reasonable: {}", avg); + } + + // Scenario 5: Monotonic behavior verification + { + const TIME_CONSTANT_SECS: f64 = 5.0; + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + // Add samples in strictly increasing order + let mut previous_avg = 0.0; + for i in 1..=1000 { + tracker.record_value(i as f64); + + if i % 100 == 0 { + let current_avg = tracker.get_decayed_time_weighted_average(); + + // Average should generally increase when adding larger values + assert!( + current_avg > previous_avg, + "Average should increase with larger values: {} -> {} at iteration {}", + previous_avg, + current_avg, + i + ); + + previous_avg = current_avg; + thread::sleep(Duration::from_millis(1)); + } + } + } + + // Scenario 6: Stability under rapid updates + { + const TIME_CONSTANT_SECS: f64 = 2.0; + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + // Rapidly add many samples without any delays + for i in 0..100_000 { + tracker.record_value((i % 10) as f64); // Values 0-9 + } + + let avg = tracker.get_decayed_time_weighted_average(); + + assert!(avg.is_finite(), "Rapid updates should maintain stability"); + assert!( + avg >= 0.0 && avg <= 9.0, + "Average should be bounded: {}", + avg + ); + + // Should be close to the mean of 0-9 = 4.5 + assert!( + (avg - 4.5).abs() < 1.0, + "Average should be close to sample mean with rapid updates: {}", + avg + ); + } + + // Scenario 7: Verify internal weight tracking remains stable + { + const TIME_CONSTANT_SECS: f64 = 1.0; + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + // Add samples over a long period to test weight accumulation + for i in 0..1000 { + tracker.record_value(50.0); // Constant value + + if i % 100 == 0 { + thread::sleep(Duration::from_millis(100)); + } + } + + // Internal weights should be reasonable (not infinite or zero) + let avg = tracker.get_decayed_time_weighted_average(); + assert!(avg.is_finite(), "Internal state should remain stable"); + assert!( + (avg - 50.0).abs() < 1e-4, + "Constant values should maintain constant average" + ); + + // Test that the tracker can still respond to new values + tracker.record_value(100.0); + let new_avg = tracker.get_decayed_time_weighted_average(); + assert!( + new_avg > 50.0, + "Tracker should still respond to new values after long series" + ); + } - Ok(()) + println!("✓ All numerical stability tests passed"); } } diff --git a/lib/llm/src/http/service/service_v2.rs b/lib/llm/src/http/service/service_v2.rs index 09ef548105..ff730ae8eb 100644 --- a/lib/llm/src/http/service/service_v2.rs +++ b/lib/llm/src/http/service/service_v2.rs @@ -152,9 +152,7 @@ impl HttpServiceConfigBuilder { let config: HttpServiceConfig = self.build_internal()?; let 
model_manager = Arc::new(ModelManager::new()); - let rate_limiter = Arc::new(RateLimiter::new( - config.rate_limiter_config.unwrap_or_default(), - )); + let rate_limiter = Arc::new(RateLimiter::new(config.rate_limiter_config)); let state = Arc::new(State::new(model_manager, rate_limiter)); // enable prometheus metrics From 58853672d97ba34df1086a6be3e26612922e5e55 Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Tue, 15 Jul 2025 17:04:23 +0100 Subject: [PATCH 04/11] improve code structure along the repository --- components/http/src/main.rs | 76 +++++++++++- docs/guides/rate_limiting.md | 117 ++++++++++++++++++ lib/llm/benches/rate_limiter.rs | 12 +- lib/llm/src/http/service/metrics.rs | 60 ++++++++++ lib/llm/src/http/service/openai.rs | 66 +++++++++-- lib/llm/src/http/service/rate_limiter.rs | 144 +++++++++++++++-------- lib/llm/src/http/service/service_v2.rs | 6 + lib/llm/tests/http-service.rs | 1 + 8 files changed, 409 insertions(+), 73 deletions(-) create mode 100644 docs/guides/rate_limiting.md diff --git a/components/http/src/main.rs b/components/http/src/main.rs index 7762c21f50..c94d906674 100644 --- a/components/http/src/main.rs +++ b/components/http/src/main.rs @@ -4,9 +4,10 @@ use clap::Parser; use dynamo_llm::discovery::{ModelWatcher, MODEL_ROOT_PATH}; +use dynamo_llm::http::service::rate_limiter::{RateLimiterConfig}; use dynamo_llm::http::service::service_v2::HttpService; use dynamo_runtime::{ - logging, pipeline::RouterMode, transports::etcd::PrefixWatcher, DistributedRuntime, Result, + logging, pipeline::RouterMode, transports::etcd::PrefixWatcher, DistributedRuntime, Result, error, Runtime, Worker, }; @@ -28,6 +29,42 @@ struct Args { /// Component name for the service #[arg(long, default_value = "http")] component: String, + + /// Enable rate limiting + #[arg(long, default_value = "false")] + enable_rate_limiting: bool, + + /// Time to first token threshold in milliseconds + #[arg( + long, + default_value = "1000.0", + help = "Desired time to first token threshold in milliseconds" + )] + ttft_threshold_ms: f64, + + /// Inter-token latency threshold in milliseconds + #[arg( + long, + default_value = "30.0", + help = "Desired inter-token latency threshold in milliseconds" + )] + itl_threshold_ms: f64, + + /// Time constant for the rate limiter in seconds + #[arg( + long, + default_value = "15.0", + help = "Time constant for the exponential moving average calculation in the rate limiter, in seconds" + )] + time_constant_secs: f64, + + /// Per model rate limiting + #[arg( + long, + default_value = "false", + help = "Track rate limits per model separately, instead of globally" + )] + per_model_rate_limiting: bool, } #[tokio::main] @@ -41,10 +78,21 @@ async fn app(runtime: Runtime) -> Result<()> { let distributed = DistributedRuntime::from_settings(runtime.clone()).await?; let args = Args::parse(); - let http_service = HttpService::builder() - .port(args.port) - .host(args.host) - .build()?; + validate_args(&args)?; + + let mut http_service_builder = HttpService::builder().port(args.port).host(args.host); + + if args.enable_rate_limiting { + let rate_limiter = RateLimiterConfig::new( + args.ttft_threshold_ms / 1_000.0, + args.itl_threshold_ms / 1_000.0, + args.time_constant_secs, + args.per_model_rate_limiting, + )?; + http_service_builder = http_service_builder.with_rate_limiter(rate_limiter); + } + + let http_service = http_service_builder.build()?; let manager = http_service.state().manager_clone(); // todo - use the IntoComponent trait to register the component @@ -71,3 
+119,21 @@ async fn app(runtime: Runtime) -> Result<()> { // Run the service http_service.run(runtime.child_token()).await } + +fn validate_args(args: &Args) -> Result<()> { + if args.enable_rate_limiting { + if args.ttft_threshold_ms <= 0.0 { + return Err(error!("Time to first token threshold must be greater than 0")); + } + + if args.itl_threshold_ms <= 0.0 { + return Err(error!("Inter-token latency threshold must be greater than 0")); + } + + if args.time_constant_secs <= 0.0 { + return Err(error!("Time constant must be greater than 0")); + } + } + + Ok(()) +} \ No newline at end of file diff --git a/docs/guides/rate_limiting.md b/docs/guides/rate_limiting.md new file mode 100644 index 0000000000..2944a9302c --- /dev/null +++ b/docs/guides/rate_limiting.md @@ -0,0 +1,117 @@ +# Rate Limiting Guide + +## Overview + +The Dynamo LLM service includes an intelligent rate limiter that monitors service performance metrics and automatically throttles requests when quality degrades. Unlike traditional rate limiters that count requests, this system focuses on maintaining good user experience by monitoring: + +- **Time to First Token (TTFT)** - How long users wait for the first response +- **Inter-Token Latency (ITL)** - How long between subsequent tokens + +## How It Works + +### Time-Weighted Exponential Moving Average + +The rate limiter uses a sophisticated time-weighted exponential moving average (EMA) algorithm: + +```text +average = sum(value * weight) / sum(weight) +weight = exp(-age / time_constant_secs) +``` + + +This means: +- Recent samples have higher influence on the average +- Old samples decay exponentially over time +- System "recovers" during idle periods + +### Decision Logic + +For each incoming request, the system: +1. Computes current decayed EMA for TTFT and ITL +2. Compares against configured thresholds +3. Rejects request if either threshold is exceeded +4. 
Logs detailed metrics for observability + +## Configuration + +### Environment Variables + +```bash +# Enable rate limiting +export DYN_RATE_LIMITER_ENABLED=true + +# TTFT threshold in milliseconds (default: 1000ms = 1s) +export DYN_RATE_LIMITER_TTFT_THRESHOLD_MS=1500 + +# ITL threshold in milliseconds (default: 10ms) +export DYN_RATE_LIMITER_ITL_THRESHOLD_MS=15 + +# Time constant for EMA decay (default: 30s) +export DYN_RATE_LIMITER_TIME_CONSTANT_SECS=60 + +# Enable per-model vs global limits (default: false) +export DYN_RATE_LIMITER_PER_MODEL_LIMITS=true +``` + +### Command Line Arguments + +```bash +dynamo-http \ + --enable-rate-limiting \ + --ttft-threshold-ms 1500 \ + --itl-threshold-ms 15 \ + --time-constant-secs 60 \ + --per-model-limits +``` + +### Programmatic Configuration + +```rust +use dynamo_llm::http::service::rate_limiter::RateLimiterConfig; + +let config = RateLimiterConfig::new( + 1500.0, // TTFT threshold (ms) + 15.0, // ITL threshold (ms) + 60.0, // Time constant (s) + true, // Per-model limits +); + +let http_service = HttpService::builder() + .with_rate_limiter_config(config) + .build()?; +``` + +## Monitoring + +### Prometheus Metrics + +The rate limiter exposes several Prometheus metrics: + +**Requests rejected by rate limiter:** + +```text +nv_llm_http_service_rate_limit_requests_total{model, endpoint, request_type, status} +``` + +**Current TTFT metrics:** + +```text +nv_llm_http_service_time_to_first_token_seconds{model} +``` + +**Current ITL metrics:** + +```text +nv_llm_http_service_inter_token_latency_seconds{model} +``` + +### Log Messages + +When requests are rejected, detailed log messages are emitted: + +```text +WARN Rate limit exceeded for model deepseek-ai/DeepSeek-R1: RateLimiterMetrics { +TTFT: TimeWeightedDiagnostics { decayed_time_weighted_average: 2.450, time_constant_secs: 30.0, last_weighted_sum: 1.245, duration_since_last_update: 0.125 }, +ITL: TimeWeightedDiagnostics { decayed_time_weighted_average: 0.025, time_constant_secs: 30.0, last_weighted_sum: 1.245, duration_since_last_update: 0.125 } +} +``` \ No newline at end of file diff --git a/lib/llm/benches/rate_limiter.rs b/lib/llm/benches/rate_limiter.rs index 0d6bacd7d3..6153f4a594 100644 --- a/lib/llm/benches/rate_limiter.rs +++ b/lib/llm/benches/rate_limiter.rs @@ -96,7 +96,7 @@ fn bench_time_constants(c: &mut Criterion) { fn bench_rate_limiter_decisions(c: &mut Criterion) { let mut group = c.benchmark_group("rate_limiter_decisions"); - let config = RateLimiterConfig::new(100.0, 10.0, 10.0, false); + let config = RateLimiterConfig::new(100.0, 10.0, 10.0, false).unwrap(); group.bench_function("should_reject_with_data", |b| { let rate_limiter = RateLimiter::new(Some(config.clone())); @@ -147,7 +147,7 @@ fn bench_concurrent_access(c: &mut Criterion) { &thread_count, |b, &num_threads| { b.iter(|| { - let config = RateLimiterConfig::new(1000.0, 10.0, 30.0, false); + let config = RateLimiterConfig::new(1000.0, 10.0, 30.0, false).unwrap(); let rate_limiter = Arc::new(RateLimiter::new(Some(config))); let handles: Vec<_> = (0..num_threads) @@ -203,7 +203,7 @@ fn bench_memory_patterns(c: &mut Criterion) { }); group.bench_function("per_model_isolation", |b| { - let config = RateLimiterConfig::new(1000.0, 10.0, 30.0, true); + let config = RateLimiterConfig::new(1000.0, 10.0, 30.0, true).unwrap(); b.iter(|| { let rate_limiter = RateLimiter::new(Some(config.clone())); @@ -287,15 +287,15 @@ fn bench_configuration_comparison(c: &mut Criterion) { let configs = vec![ ( "aggressive", - 
RateLimiterConfig::new(1000.0, 10.0, 1.0, false), + RateLimiterConfig::new(1000.0, 10.0, 1.0, false).unwrap(), ), ( "balanced", - RateLimiterConfig::new(1000.0, 10.0, 10.0, false), + RateLimiterConfig::new(1000.0, 10.0, 10.0, false).unwrap(), ), ( "conservative", - RateLimiterConfig::new(1000.0, 10.0, 60.0, false), + RateLimiterConfig::new(1000.0, 10.0, 60.0, false).unwrap(), ), ]; diff --git a/lib/llm/src/http/service/metrics.rs b/lib/llm/src/http/service/metrics.rs index 212364b042..9f1dfeab29 100644 --- a/lib/llm/src/http/service/metrics.rs +++ b/lib/llm/src/http/service/metrics.rs @@ -20,6 +20,9 @@ pub const REQUEST_STATUS_SUCCESS: &str = "success"; /// Value for the `status` label in the request counter if the request failed pub const REQUEST_STATUS_ERROR: &str = "error"; +/// Value for the `status` label in the request counter if the request was rejected by the rate limiter +pub const REQUEST_STATUS_REJECTED: &str = "rejected"; + /// Partial value for the `type` label in the request counter for streaming requests pub const REQUEST_TYPE_STREAM: &str = "stream"; @@ -35,6 +38,8 @@ pub struct Metrics { output_sequence_length: HistogramVec, time_to_first_token: HistogramVec, inter_token_latency: HistogramVec, + rate_limit_ema_ttft: HistogramVec, + rate_limit_ema_itl: HistogramVec, } /// RAII object for inflight gauge and request counters @@ -75,10 +80,21 @@ pub enum RequestType { Stream, } +impl RequestType { + pub fn from_streaming_boolean(is_streaming: bool) -> Self { + if is_streaming { + RequestType::Stream + } else { + RequestType::Unary + } + } +} + /// Status pub enum Status { Success, Error, + Rejected, } /// Track response-specific metrics @@ -202,6 +218,31 @@ impl Metrics { ) .unwrap(); + let rate_limit_ema_ttft = HistogramVec::new( + HistogramOpts::new( + format!("{}_http_service_rate_limit_ema_ttft_seconds", prefix), + "Time to first token in seconds", + ) + .buckets(vec![ + 0.0, 0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, + 60.0, 120.0, 240.0, 480.0, + ]), + &["model"], + ) + .unwrap(); + + let rate_limit_ema_itl = HistogramVec::new( + HistogramOpts::new( + format!("{}_http_service_rate_limit_ema_itl_seconds", prefix), + "Inter-token latency in seconds", + ) + .buckets(vec![ + 0.0, 0.001, 0.005, 0.01, 0.015, 0.02, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.0, + ]), + &["model"], + ) + .unwrap(); + Metrics { rate_limit_requests_counter, request_counter, @@ -211,6 +252,8 @@ impl Metrics { output_sequence_length, time_to_first_token, inter_token_latency, + rate_limit_ema_ttft, + rate_limit_ema_itl, } } @@ -324,6 +367,8 @@ impl Metrics { registry.register(Box::new(self.output_sequence_length.clone()))?; registry.register(Box::new(self.time_to_first_token.clone()))?; registry.register(Box::new(self.inter_token_latency.clone()))?; + registry.register(Box::new(self.rate_limit_ema_ttft.clone()))?; + registry.register(Box::new(self.rate_limit_ema_itl.clone()))?; Ok(()) } @@ -360,6 +405,20 @@ impl Metrics { ) -> ResponseMetricCollector { ResponseMetricCollector::new(self, model.to_string().to_lowercase(), rate_limiter) } + + /// Record the time to first token for the given model, endpoint, and request type + pub fn record_rate_limit_ttft(&self, ttft_ema: f64, model: &str) { + self.rate_limit_ema_ttft + .with_label_values(&[model]) + .observe(ttft_ema); + } + + /// Record the inter-token latency for the given model, endpoint, and request type + pub fn record_rate_limit_itl(&self, itl_ema: f64, model: &str) { + self.rate_limit_ema_itl + 
.with_label_values(&[model]) + .observe(itl_ema); + } } impl InflightGuard { @@ -450,6 +509,7 @@ impl Status { match self { Status::Success => REQUEST_STATUS_SUCCESS, Status::Error => REQUEST_STATUS_ERROR, + Status::Rejected => REQUEST_STATUS_REJECTED, } } } diff --git a/lib/llm/src/http/service/openai.rs b/lib/llm/src/http/service/openai.rs index 48849688e9..7f7155ccd2 100644 --- a/lib/llm/src/http/service/openai.rs +++ b/lib/llm/src/http/service/openai.rs @@ -28,15 +28,21 @@ use super::{ metrics::{Endpoint, InflightGuard, ResponseMetricCollector}, service_v2, RouteDoc, }; -use crate::preprocessor::LLMMetricAnnotation; -use crate::protocols::openai::{ - chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionResponse}, - completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, - embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse}, - responses::{NvCreateResponse, NvResponse}, -}; use crate::request_template::RequestTemplate; use crate::types::Annotated; +use crate::{ + http::service::metrics::{RequestType, Status}, + preprocessor::LLMMetricAnnotation, +}; +use crate::{ + http::service::rate_limiter::ShouldRejectResult, + protocols::openai::{ + chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionResponse}, + completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, + embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse}, + responses::{NvCreateResponse, NvResponse}, + }, +}; #[derive(Serialize, Deserialize)] pub(crate) struct ErrorResponse { @@ -149,7 +155,12 @@ async fn completions( check_ready(&state)?; // Rate limit check - should_reject_request(&state, &request.inner.model)?; + should_reject_request( + &state, + &request.inner.model, + &Endpoint::Completions, + &RequestType::from_streaming_boolean(request.inner.stream.unwrap_or(false)), + )?; // todo - extract distributed tracing id and context id from headers let request_id = uuid::Uuid::new_v4().to_string(); @@ -308,7 +319,12 @@ async fn chat_completions( check_ready(&state)?; // Rate limit check - should_reject_request(&state, &request.inner.model)?; + should_reject_request( + &state, + &request.inner.model, + &Endpoint::ChatCompletions, + &RequestType::from_streaming_boolean(request.inner.stream.unwrap_or(false)), + )?; // Handle unsupported fields - if Some(resp) is returned by // validate_chat_completion_unsupported_fields, @@ -468,6 +484,15 @@ async fn responses( // return a 503 if the service is not ready check_ready(&state)?; + // Rate limit check + // TODO: handle streaming, currently just unary + should_reject_request( + &state, + &request.inner.model, + &Endpoint::Responses, + &RequestType::Unary, + )?; + // Handle unsupported fields - if Some(resp) is returned by validate_unsupported_fields, // then a field was used that is unsupported. We will log an error message // and early return a 501 NOT_IMPLEMENTED status code. Otherwise, proceeed. 
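
For orientation, the decision these handlers delegate to is just the time-weighted EMA from the guide (`weight = exp(-age / time_constant_secs)`) compared against the configured TTFT/ITL thresholds. Below is a minimal, self-contained Rust sketch of that math under stated assumptions: `EmaSketch`, the sample vector, the 15 s time constant, and the example thresholds are illustrative only and are not the crate's real types, which keep incremental weighted sums per model behind a `DashMap` and return richer diagnostics.

```rust
use std::time::Instant;

/// Illustrative stand-in for the tracker behind `should_reject`: each sample is
/// weighted by exp(-age / time_constant_secs), so recent TTFT/ITL observations
/// dominate and idle periods let the average decay.
struct EmaSketch {
    samples: Vec<(Instant, f64)>,
    time_constant_secs: f64,
}

impl EmaSketch {
    fn new(time_constant_secs: f64) -> Self {
        Self { samples: Vec::new(), time_constant_secs }
    }

    fn record(&mut self, value_secs: f64) {
        self.samples.push((Instant::now(), value_secs));
    }

    /// average = sum(value * weight) / sum(weight), weight = exp(-age / tau)
    fn decayed_average(&self) -> f64 {
        let now = Instant::now();
        let (mut num, mut den) = (0.0_f64, 0.0_f64);
        for (t, v) in &self.samples {
            let age = now.duration_since(*t).as_secs_f64();
            let w = (-age / self.time_constant_secs).exp();
            num += v * w;
            den += w;
        }
        if den > 0.0 { num / den } else { 0.0 }
    }
}

fn main() {
    // Thresholds in seconds, mirroring the CLI defaults above (1 s TTFT, 30 ms ITL).
    let (ttft_threshold_secs, itl_threshold_secs) = (1.0, 0.03);
    let mut ttft = EmaSketch::new(15.0);
    let mut itl = EmaSketch::new(15.0);

    // A slow first token pushes the TTFT average over its threshold.
    ttft.record(1.8);
    itl.record(0.012);

    let should_reject = ttft.decayed_average() > ttft_threshold_secs
        || itl.decayed_average() > itl_threshold_secs;
    println!("should_reject = {should_reject}"); // true: TTFT EMA exceeds 1 s
}
```

The production tracker avoids storing every sample by carrying a running weighted sum and the time of the last update (the `last_weighted_sum` and `duration_since_last_update` fields visible in the diagnostics log), which gives the same average in O(1) per record.
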
@@ -573,16 +598,35 @@ async fn responses( pub fn should_reject_request( state: &Arc, model: &str, + endpoint: &Endpoint, + request_type: &RequestType, ) -> Result<(), (StatusCode, Json)> { if !state.rate_limiter().is_enabled() { return Ok(()); } - let should_reject = state.rate_limiter().should_reject(model); + let ShouldRejectResult { + should_reject, + decayed_ttft_ema, + decayed_itl_ema, + } = state.rate_limiter().should_reject(model); + + state + .metrics_clone() + .record_rate_limit_ttft(decayed_ttft_ema, model); + state + .metrics_clone() + .record_rate_limit_itl(decayed_itl_ema, model); if should_reject { + state.metrics_clone().inc_rate_limit_requests_counter( + model, + endpoint, + request_type, + &Status::Rejected, + ); return Err(ErrorResponse::rate_limit_exceeded(&format!( - "Rate limit exceeded for current request and model: {model}. Please retry later." + "Too many requests: rate limit exceeded for current request and model: {model}. Please retry later." ))); } diff --git a/lib/llm/src/http/service/rate_limiter.rs b/lib/llm/src/http/service/rate_limiter.rs index 6ceb4cce42..b832aeeb16 100644 --- a/lib/llm/src/http/service/rate_limiter.rs +++ b/lib/llm/src/http/service/rate_limiter.rs @@ -38,6 +38,7 @@ use std::time::Instant; +use anyhow::Result; use dashmap::DashMap; use validator::Validate; @@ -46,10 +47,10 @@ use validator::Validate; pub struct RateLimiterConfig { /// Threshold for the time to first token metric #[validate(range(min = 0.0))] - ttft_threshold_ms: f64, + ttft_threshold_secs: f64, /// Threshold for the inter-token latency metric #[validate(range(min = 0.0))] - itl_threshold_ms: f64, + itl_threshold_secs: f64, /// Time constant for the time-weighted EMA #[validate(range(min = 0.001))] time_constant_secs: f64, @@ -59,23 +60,29 @@ pub struct RateLimiterConfig { impl RateLimiterConfig { pub fn new( - ttft_threshold_ms: f64, - itl_threshold_ms: f64, + ttft_threshold_secs: f64, + itl_threshold_secs: f64, time_constant_secs: f64, per_model_limits: bool, - ) -> Self { - Self { - ttft_threshold_ms, - itl_threshold_ms, + ) -> Result { + let config: RateLimiterConfig = Self { + ttft_threshold_secs, + itl_threshold_secs, time_constant_secs, per_model_limits, - } + }; + + config + .validate() + .map_err(|e| anyhow::anyhow!("Invalid rate limiter config: {}", e))?; + + Ok(config) } pub fn empty() -> Self { Self { - ttft_threshold_ms: 0.0, - itl_threshold_ms: 0.0, + ttft_threshold_secs: 0.0, + itl_threshold_secs: 0.0, time_constant_secs: 0.001, per_model_limits: false, } @@ -85,9 +92,9 @@ impl RateLimiterConfig { impl Default for RateLimiterConfig { fn default() -> Self { Self { - ttft_threshold_ms: 1000.0, // 1s - itl_threshold_ms: 10.0, // 10ms - time_constant_secs: 30.0, // 30s + ttft_threshold_secs: 1.0, // 1s + itl_threshold_secs: 0.1, // 100ms + time_constant_secs: 30.0, // 30s per_model_limits: false, } } @@ -292,12 +299,16 @@ impl RateLimiter { /// Check if the request should be rejected based on the cached metrics /// /// Returns true if the request should be rejected, false otherwise - pub fn should_reject(&self, model: &str) -> bool { + pub fn should_reject(&self, model: &str) -> ShouldRejectResult { let model_key = self.get_model_key(model); let model_metrics = self.model_metrics.get(&model_key); let Some(model_metrics) = model_metrics else { - return false; + return ShouldRejectResult { + should_reject: false, + decayed_ttft_ema: 0.0, + decayed_itl_ema: 0.0, + }; }; // Get decayed time-weighted EMA values @@ -310,16 +321,31 @@ impl RateLimiter { 
drop(model_metrics); - let ttft_exceeded = self.config.ttft_threshold_ms < decayed_ttft_ema; - let itl_exceeded = self.config.itl_threshold_ms < decayed_itl_ema; + let ttft_exceeded = self.config.ttft_threshold_secs < decayed_ttft_ema; + let itl_exceeded = self.config.itl_threshold_secs < decayed_itl_ema; if ttft_exceeded || itl_exceeded { let rate_limiter_metrics = self.get_metrics(&model_key); - self.log_metrics(model, rate_limiter_metrics); - return true; + self.log_metrics(model, rate_limiter_metrics, true); + return ShouldRejectResult { + should_reject: true, + decayed_ttft_ema, + decayed_itl_ema, + }; + } + + if decayed_ttft_ema > self.config.ttft_threshold_secs * 0.9 + || decayed_itl_ema > self.config.itl_threshold_secs * 0.9 + { + let rate_limiter_metrics = self.get_metrics(&model_key); + self.log_metrics(model, rate_limiter_metrics, false); } - false + ShouldRejectResult { + should_reject: false, + decayed_ttft_ema, + decayed_itl_ema, + } } /// Get current metrics and diagnostics for current model @@ -353,17 +379,33 @@ impl RateLimiter { } } - fn log_metrics(&self, model: &str, metrics: RateLimiterMetrics) { - tracing::warn!( - model = model, - ttft_threshold_ms = self.config.ttft_threshold_ms, - itl_threshold_ms = self.config.itl_threshold_ms, - "Rate limit exceeded for model {model}: {metrics}", - metrics = metrics, - ); + fn log_metrics(&self, model: &str, metrics: RateLimiterMetrics, has_exceeded: bool) { + if has_exceeded { + tracing::warn!( + model = model, + ttft_threshold_secs = self.config.ttft_threshold_secs, + itl_threshold_secs = self.config.itl_threshold_secs, + "Rate limit exceeded for model {model}: {metrics}", + metrics = metrics, + ); + } else { + tracing::info!( + model = model, + ttft_threshold_secs = self.config.ttft_threshold_secs, + itl_threshold_secs = self.config.itl_threshold_secs, + "Approaching rate limit thresholds. 
Current rate limit metrics for model {model}: {metrics}", + metrics = metrics, + ); + } } } +pub struct ShouldRejectResult { + pub should_reject: bool, + pub decayed_ttft_ema: f64, + pub decayed_itl_ema: f64, +} + #[cfg(test)] mod tests { use super::*; @@ -659,8 +701,8 @@ mod tests { const SLEEP_DURATION_MS: u64 = 1; let config = RateLimiterConfig { - ttft_threshold_ms: 1000.0, - itl_threshold_ms: 100.0, + ttft_threshold_secs: 1.0, + itl_threshold_secs: 0.1, time_constant_secs: 30.0, per_model_limits: false, }; @@ -675,10 +717,10 @@ mod tests { handles.push(thread::spawn(move || { for j in 0..NUM_RECORDS { - limiter_clone.record_ttft("model", (i * NUM_RECORDS + j) as f64); - limiter_clone.record_itl("model", (i + j) as f64); + limiter_clone.record_ttft("model", (i * NUM_RECORDS + j) as f64 / 10_000.0); + limiter_clone.record_itl("model", (i + j) as f64 / 1_000.0); - if limiter_clone.should_reject("model") { + if limiter_clone.should_reject("model").should_reject { error_count_clone.fetch_add(1, Ordering::Relaxed); } @@ -707,8 +749,8 @@ mod tests { const SLEEP_DURATION_MS: u64 = 1; let config = RateLimiterConfig { - ttft_threshold_ms: 1000.0, - itl_threshold_ms: 10.0, + ttft_threshold_secs: 1.0, + itl_threshold_secs: 0.1, time_constant_secs: 30.0, per_model_limits: false, }; @@ -723,10 +765,10 @@ mod tests { handles.push(thread::spawn(move || { for j in 0..NUM_RECORDS { - limiter_clone.record_ttft("model", (i * NUM_RECORDS + j) as f64); - limiter_clone.record_itl("model", (i + j) as f64); + limiter_clone.record_ttft("model", (i * NUM_RECORDS + j) as f64 / 1_000.0); + limiter_clone.record_itl("model", (i + j) as f64 / 100.0); - if limiter_clone.should_reject("model") { + if limiter_clone.should_reject("model").should_reject { error_count_clone.fetch_add(1, Ordering::Relaxed); } @@ -810,8 +852,8 @@ mod tests { #[test] fn test_rate_limiter_integration() { let config = RateLimiterConfig { - ttft_threshold_ms: 100.0, // 100ms - itl_threshold_ms: 5.0, // 5ms + ttft_threshold_secs: 100., // 100ms + itl_threshold_secs: 1., // 5ms time_constant_secs: 1.0, ..Default::default() }; @@ -826,7 +868,7 @@ mod tests { thread::sleep(Duration::from_millis(150)); // Wait for warmup assert!( - !limiter.should_reject("test"), + !limiter.should_reject("test").should_reject, "Should not reject with low values" ); @@ -835,7 +877,7 @@ mod tests { limiter.record_ttft("test", 300.0); assert!( - limiter.should_reject("test"), + limiter.should_reject("test").should_reject, "Should reject with high values" ); } @@ -845,8 +887,8 @@ mod tests { const NUM_SAMPLES: usize = 100; let config = RateLimiterConfig { - ttft_threshold_ms: 100.0, // 100ms - itl_threshold_ms: 5.0, // 5ms + ttft_threshold_secs: 70., + itl_threshold_secs: 0.005, time_constant_secs: 1.0, ..Default::default() }; @@ -861,7 +903,7 @@ mod tests { thread::sleep(Duration::from_millis(150)); // Wait for warmup assert!( - !limiter.should_reject("test"), + !limiter.should_reject("test").should_reject, "Should not reject with low values" ); @@ -871,7 +913,7 @@ mod tests { } assert!( - limiter.should_reject("test"), + limiter.should_reject("test").should_reject, "Should reject with high values" ); } @@ -904,14 +946,14 @@ mod tests { thread::sleep(Duration::from_millis(20)); // Both should reject model A - assert!(global_limiter.should_reject(MODEL_A)); - assert!(per_model_limiter.should_reject(MODEL_A)); + assert!(global_limiter.should_reject(MODEL_A).should_reject); + assert!(per_model_limiter.should_reject(MODEL_A).should_reject); // Global limiter should also 
reject model B (uses same "global" key) - assert!(global_limiter.should_reject(MODEL_B)); + assert!(global_limiter.should_reject(MODEL_B).should_reject); // Per-model limiter should NOT reject model B (separate tracking) - assert!(!per_model_limiter.should_reject(MODEL_B)); + assert!(!per_model_limiter.should_reject(MODEL_B).should_reject); } #[test] diff --git a/lib/llm/src/http/service/service_v2.rs b/lib/llm/src/http/service/service_v2.rs index ff730ae8eb..478faa9c6c 100644 --- a/lib/llm/src/http/service/service_v2.rs +++ b/lib/llm/src/http/service/service_v2.rs @@ -148,6 +148,12 @@ impl HttpService { } impl HttpServiceConfigBuilder { + /// Set the rate limiter config for the HTTP service. + pub fn with_rate_limiter(mut self, config: RateLimiterConfig) -> Self { + self.rate_limiter_config = Some(Some(config)); + self + } + pub fn build(self) -> Result { let config: HttpServiceConfig = self.build_internal()?; diff --git a/lib/llm/tests/http-service.rs b/lib/llm/tests/http-service.rs index 1a9b850277..019f0a6a61 100644 --- a/lib/llm/tests/http-service.rs +++ b/lib/llm/tests/http-service.rs @@ -154,6 +154,7 @@ fn compute_index(endpoint: &Endpoint, request_type: &RequestType, status: &Statu let status = match status { Status::Success => 0, Status::Error => 1, + Status::Rejected => 2, }; endpoint * 4 + request_type * 2 + status From 46868ef32ef724c859eaf5f81e98342b585f1b45 Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Tue, 15 Jul 2025 17:14:04 +0100 Subject: [PATCH 05/11] pybindings integration --- lib/bindings/python/rust/http.rs | 39 ++++++++++++++++++++++-- lib/bindings/python/rust/lib.rs | 1 + lib/bindings/python/src/dynamo/_core.pyi | 15 ++++++++- 3 files changed, 52 insertions(+), 3 deletions(-) diff --git a/lib/bindings/python/rust/http.rs b/lib/bindings/python/rust/http.rs index 3a22092334..5a98f8c5f8 100644 --- a/lib/bindings/python/rust/http.rs +++ b/lib/bindings/python/rust/http.rs @@ -37,8 +37,16 @@ pub struct HttpService { impl HttpService { #[new] #[pyo3(signature = (port=None))] - pub fn new(port: Option) -> PyResult { - let builder = service_v2::HttpService::builder().port(port.unwrap_or(8080)); + pub fn new( + port: Option, + rate_limiter_config: Option, + ) -> PyResult { + let mut builder = service_v2::HttpService::builder().port(port.unwrap_or(8080)); + + if let Some(rate_limiter_config) = rate_limiter_config { + builder = builder.with_rate_limiter(rate_limiter_config.inner); + } + let inner = builder.build().map_err(to_pyerr)?; Ok(Self { inner }) } @@ -184,3 +192,30 @@ where } } } + +#[pyclass] +#[derive(Clone)] +pub struct RateLimiterConfig { + inner: dynamo_llm::http::service::rate_limiter::RateLimiterConfig, +} + +#[pymethods] +impl RateLimiterConfig { + #[new] + pub fn new( + ttft_threshold_secs: f64, + itl_threshold_secs: f64, + time_constant_secs: f64, + per_model_rate_limiting: bool, + ) -> PyResult { + let inner = dynamo_llm::http::service::rate_limiter::RateLimiterConfig::new( + ttft_threshold_secs, + itl_threshold_secs, + time_constant_secs, + per_model_rate_limiting, + ) + .map_err(to_pyerr)?; + + Ok(Self { inner }) + } +} diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index 7754a35e45..ab78f513fb 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -98,6 +98,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git 
a/lib/bindings/python/src/dynamo/_core.pyi b/lib/bindings/python/src/dynamo/_core.pyi index c176f95521..3117d7cbca 100644 --- a/lib/bindings/python/src/dynamo/_core.pyi +++ b/lib/bindings/python/src/dynamo/_core.pyi @@ -807,7 +807,7 @@ class HttpService: It is a OpenAI compatible http ingress into the Dynamo Distributed Runtime. """ - ... + ... class HttpError: """ @@ -816,6 +816,19 @@ class HttpError: ... +class RateLimiterConfig: + """ + A configuration for the HTTP service rate limiter logic + """ + + def __init__( + self, + ttft_threshold_secs: float, + itl_threshold_secs: float, + time_constant_secs: float, + per_model_limits: bool = False + ) -> None: ... + class HttpAsyncEngine: """ An async engine for a distributed Dynamo http service. This is an extension of the From 5b56708f3f87e20207fde641237a6beb75c9ac4e Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Tue, 15 Jul 2025 17:17:20 +0100 Subject: [PATCH 06/11] small refactors --- lib/bindings/python/rust/http.rs | 4 ++-- lib/bindings/python/src/dynamo/_core.pyi | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/lib/bindings/python/rust/http.rs b/lib/bindings/python/rust/http.rs index 5a98f8c5f8..bd0aa3d8f9 100644 --- a/lib/bindings/python/rust/http.rs +++ b/lib/bindings/python/rust/http.rs @@ -206,13 +206,13 @@ impl RateLimiterConfig { ttft_threshold_secs: f64, itl_threshold_secs: f64, time_constant_secs: f64, - per_model_rate_limiting: bool, + per_model_limits: bool, ) -> PyResult { let inner = dynamo_llm::http::service::rate_limiter::RateLimiterConfig::new( ttft_threshold_secs, itl_threshold_secs, time_constant_secs, - per_model_rate_limiting, + per_model_limits, ) .map_err(to_pyerr)?; diff --git a/lib/bindings/python/src/dynamo/_core.pyi b/lib/bindings/python/src/dynamo/_core.pyi index 3117d7cbca..6331ee21ad 100644 --- a/lib/bindings/python/src/dynamo/_core.pyi +++ b/lib/bindings/python/src/dynamo/_core.pyi @@ -807,7 +807,11 @@ class HttpService: It is a OpenAI compatible http ingress into the Dynamo Distributed Runtime. """ - ... + def __init__( + self, + port: Optional[int] = None, + rate_limiter_config: Optional[RateLimiterConfig] = None + ) -> None: ... 
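
The `RateLimiterConfig` pyclass above is a thin wrapper over the Rust-side constructor and builder from the earlier patches. For reference, a hedged sketch of the equivalent Rust wiring is shown below; the import paths and builder calls match those used elsewhere in this series, while the `anyhow::Result` wrapper, the port, and the literal threshold values are illustrative.

```rust
use dynamo_llm::http::service::{rate_limiter::RateLimiterConfig, service_v2::HttpService};

fn main() -> anyhow::Result<()> {
    // Equivalent of HttpService(port=8080, rate_limiter_config=RateLimiterConfig(...))
    // from Python. Thresholds are in seconds; new() validates them and returns an
    // error on non-positive values.
    let config = RateLimiterConfig::new(
        1.0,   // ttft_threshold_secs
        0.03,  // itl_threshold_secs
        15.0,  // time_constant_secs (EMA decay)
        false, // per_model_limits: track one global window across models
    )?;

    let _service = HttpService::builder()
        .port(8080)
        .with_rate_limiter(config)
        .build()?;
    Ok(())
}
```

From Python, the same configuration is passed as the second argument of `HttpService`, as declared in the updated `_core.pyi` stub.
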
class HttpError: """ From 182c6153179021bfffc1b4812792e9116495df12 Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Tue, 15 Jul 2025 18:19:30 +0100 Subject: [PATCH 07/11] add integration tests --- docs/guides/rate_limiting.md | 50 ++- lib/llm/src/http/service/metrics.rs | 3 + lib/llm/tests/http-service.rs | 622 ++++++++++++++++++++++++++++ 3 files changed, 674 insertions(+), 1 deletion(-) diff --git a/docs/guides/rate_limiting.md b/docs/guides/rate_limiting.md index 2944a9302c..acca3a9a14 100644 --- a/docs/guides/rate_limiting.md +++ b/docs/guides/rate_limiting.md @@ -114,4 +114,52 @@ WARN Rate limit exceeded for model deepseek-ai/DeepSeek-R1: RateLimiterMetrics { TTFT: TimeWeightedDiagnostics { decayed_time_weighted_average: 2.450, time_constant_secs: 30.0, last_weighted_sum: 1.245, duration_since_last_update: 0.125 }, ITL: TimeWeightedDiagnostics { decayed_time_weighted_average: 0.025, time_constant_secs: 30.0, last_weighted_sum: 1.245, duration_since_last_update: 0.125 } } -``` \ No newline at end of file +``` + + +## Tuning Guidelines + +### Time Constant +- **Shorter (10-30s)**: Faster reaction to load changes, more sensitive +- **Longer (60-120s)**: Smoother operation, less reactive to spikes + +### TTFT Threshold +- **Conservative (500-1000ms)**: Maintains very responsive feel +- **Moderate (1000-2000ms)**: Balances throughput with responsiveness +- **Aggressive (2000ms+)**: Prioritizes throughput over latency + +### ITL Threshold +- **Conservative (5-10ms)**: Ensures smooth streaming experience +- **Moderate (10-20ms)**: Allows some latency for higher throughput +- **Aggressive (20ms+)**: Accepts choppier streaming for max throughput + +### Per-Model vs Global +- **Per-Model**: Better for multi-tenant scenarios with different SLAs +- **Global**: Simpler for single-tenant or uniform SLA scenarios + +## Best Practices + +1. **Start Conservative**: Begin with lower thresholds and increase based on user feedback +2. **Monitor Closely**: Watch both rate limit counters and user-facing metrics +3. **Load Test**: Validate behavior under realistic load patterns +4. **Document SLAs**: Clearly communicate expected performance to users +5. 
**Alert on Rejections**: Set up alerts when rejection rates exceed acceptable levels + +## Troubleshooting + +### High Rejection Rates +- Check if system is genuinely overloaded +- Consider increasing thresholds temporarily +- Scale backend resources +- Investigate specific models causing issues + +### No Rejections During Overload +- Verify rate limiter is enabled +- Check threshold configuration +- Ensure metrics are being recorded properly +- Review time constant settings + +### Inconsistent Behavior +- Check if per-model limits are configured correctly +- Review metric collection for gaps +- Validate system clock stability \ No newline at end of file diff --git a/lib/llm/src/http/service/metrics.rs b/lib/llm/src/http/service/metrics.rs index 9f1dfeab29..a0ce4654a4 100644 --- a/lib/llm/src/http/service/metrics.rs +++ b/lib/llm/src/http/service/metrics.rs @@ -128,6 +128,9 @@ impl Metrics { /// - `{prefix}_http_service_output_sequence_tokens` - HistogramVec for output sequence length in tokens /// - `{prefix}_http_service_time_to_first_token_seconds` - HistogramVec for time to first token in seconds /// - `{prefix}_http_service_inter_token_latency_seconds` - HistogramVec for inter-token latency in seconds + /// - `{prefix}_http_service_rate_limit_requests_total` - IntCounterVec for the total number of requests rejected by the rate limiter + /// - `{prefix}_http_service_rate_limit_ema_ttft_seconds` - HistogramVec for time to first token in seconds + /// - `{prefix}_http_service_rate_limit_ema_itl_seconds` - HistogramVec for inter-token latency in seconds pub fn new(prefix: &str) -> Self { let rate_limit_requests_counter = IntCounterVec::new( Opts::new( diff --git a/lib/llm/tests/http-service.rs b/lib/llm/tests/http-service.rs index 019f0a6a61..fb8b15c9fd 100644 --- a/lib/llm/tests/http-service.rs +++ b/lib/llm/tests/http-service.rs @@ -18,6 +18,7 @@ use async_stream::stream; use dynamo_llm::http::service::{ error::HttpError, metrics::{Endpoint, RequestType, Status}, + rate_limiter::RateLimiterConfig, service_v2::HttpService, Metrics, }; @@ -471,3 +472,624 @@ async fn test_http_service() { cancel_token.cancel(); task.await.unwrap().unwrap(); } + +/// Engine that simulates low TTFT to trigger rate limiting +struct SlowTTFTEngine { + ttft_delay_ms: u64, +} + +#[async_trait] +impl + AsyncEngine< + SingleIn, + ManyOut>, + Error, + > for SlowTTFTEngine +{ + async fn generate( + &self, + request: SingleIn, + ) -> Result>, Error> { + let (request, context) = request.transfer(()); + let ctx = context.context(); + + let generator = request.response_generator(); + let ttft_delay_ms = self.ttft_delay_ms; + + let stream = stream! 
{
+            // Simulate slow TTFT
+            tokio::time::sleep(std::time::Duration::from_millis(ttft_delay_ms)).await;
+
+            // Generate a few tokens with normal ITL
+            for i in 0..3 {
+                let inner = generator.create_choice(i, Some(format!("token {i}")), None, None);
+                let output = NvCreateChatCompletionStreamResponse { inner };
+                yield Annotated::from_data(output);
+
+                if i < 2 {
+                    tokio::time::sleep(std::time::Duration::from_millis(10)).await; // Normal ITL
+                }
+            }
+        };
+
+        Ok(ResponseStream::new(Box::pin(stream), ctx))
+    }
+}
+
+/// Engine that simulates slow ITL to trigger rate limiting
+struct SlowITLEngine {
+    itl_delay_ms: u64,
+}
+
+#[async_trait]
+impl
+    AsyncEngine<
+        SingleIn<NvCreateChatCompletionRequest>,
+        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
+        Error,
+    > for SlowITLEngine
+{
+    async fn generate(
+        &self,
+        request: SingleIn<NvCreateChatCompletionRequest>,
+    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
+        let (request, context) = request.transfer(());
+        let ctx = context.context();
+
+        let generator = request.response_generator();
+        let itl_delay_ms = self.itl_delay_ms;
+
+        let stream = stream! {
+            // Fast TTFT
+            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
+
+            // Generate tokens with slow ITL
+            for i in 0..5 {
+                let inner = generator.create_choice(i, Some(format!("token {i}")), None, None);
+                let output = NvCreateChatCompletionStreamResponse { inner };
+                yield Annotated::from_data(output);
+
+                if i < 4 {
+                    tokio::time::sleep(std::time::Duration::from_millis(itl_delay_ms)).await; // Slow ITL
+                }
+            }
+        };
+
+        Ok(ResponseStream::new(Box::pin(stream), ctx))
+    }
+}
+
+#[tokio::test]
+async fn test_rate_limiting_triggers_correctly() {
+    // Create rate limiter config with low thresholds for testing
+    let rate_limiter_config = RateLimiterConfig::new(
+        1.0, // TTFT threshold: 1 second
+        0.1, // ITL threshold: 100ms
+        5.0, // Time constant: 5 seconds
+        false, // Global rate limiting
+    )
+    .unwrap();
+
+    let service = HttpService::builder()
+        .port(8990)
+        .with_rate_limiter(rate_limiter_config)
+        .build()
+        .unwrap();
+
+    let state = service.state_clone();
+    let manager = state.manager();
+
+    let token = CancellationToken::new();
+    let cancel_token = token.clone();
+    let task = tokio::spawn(async move { service.run(token.clone()).await });
+
+    // Add engines with different performance characteristics
+    let fast_engine = Arc::new(CounterEngine {});
+    let slow_ttft_engine = Arc::new(SlowTTFTEngine {
+        ttft_delay_ms: 1500,
+    }); // 1.5s TTFT
+    let slow_itl_engine = Arc::new(SlowITLEngine { itl_delay_ms: 200 }); // 200ms ITL
+
+    manager
+        .add_chat_completions_model("fast", fast_engine)
+        .unwrap();
+    manager
+        .add_chat_completions_model("slow_ttft", slow_ttft_engine)
+        .unwrap();
+    manager
+        .add_chat_completions_model("slow_itl", slow_itl_engine)
+        .unwrap();
+
+    let client = reqwest::Client::new();
+    let metrics = state.metrics_clone();
+
+    // Wait for service to be ready
+    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
+
+    // Test 1: Fast model should work fine initially
+    let request = async_openai::types::CreateChatCompletionRequestArgs::default()
+        .model("fast")
+        .messages(vec![
+            async_openai::types::ChatCompletionRequestMessage::User(
+                async_openai::types::ChatCompletionRequestUserMessage {
+                    content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
+                        "test".to_string(),
+                    ),
+                    name: None,
+                },
+            ),
+        ])
+        .stream(true)
+        .max_tokens(10 as u32)
+        .build()
+        .unwrap();
+
+    let response = client
+        .post("http://localhost:8990/v1/chat/completions")
+        .json(&request)
+        .send()
+        .await
+        .unwrap();
+
+    assert!(
+        response.status().is_success(),
+        "Fast model should work initially"
+    );
+    let _ = response.bytes().await.unwrap();
+
+    // Test 2: Slow TTFT model should trigger rate limiting after a few requests
+    let mut slow_ttft_request = request.clone();
+    slow_ttft_request.model = "slow_ttft".to_string();
+
+    // Make several requests to build up the EMA
+    for i in 0..3 {
+        println!("Sending slow TTFT request {}", i + 1);
+        let response = client
+            .post("http://localhost:8990/v1/chat/completions")
+            .json(&slow_ttft_request)
+            .send()
+            .await
+            .unwrap();
+
+        // First few requests should succeed (building up EMA)
+        if i < 2 {
+            assert!(
+                response.status().is_success(),
+                "Slow TTFT request {} should succeed while building EMA",
+                i + 1
+            );
+            let _ = response.bytes().await.unwrap();
+        } else {
+            // Later requests should be rate limited
+            if response.status() == StatusCode::TOO_MANY_REQUESTS {
+                println!("Rate limiting triggered after {} requests", i + 1);
+                break;
+            } else {
+                let _ = response.bytes().await.unwrap();
+            }
+        }
+
+        tokio::time::sleep(std::time::Duration::from_millis(100)).await; // Small delay between requests
+    }
+
+    // Test 3: Slow ITL model should also trigger rate limiting
+    let mut slow_itl_request = request.clone();
+    slow_itl_request.model = "slow_itl".to_string();
+
+    for i in 0..3 {
+        println!("Sending slow ITL request {}", i + 1);
+        let response = client
+            .post("http://localhost:8990/v1/chat/completions")
+            .json(&slow_itl_request)
+            .send()
+            .await
+            .unwrap();
+
+        if i < 2 {
+            assert!(
+                response.status().is_success(),
+                "Slow ITL request {} should succeed while building EMA",
+                i + 1
+            );
+            let _ = response.bytes().await.unwrap();
+        } else {
+            // Later requests should be rate limited
+            if response.status() == StatusCode::TOO_MANY_REQUESTS {
+                println!("ITL rate limiting triggered after {} requests", i + 1);
+                break;
+            } else {
+                let _ = response.bytes().await.unwrap();
+            }
+        }
+
+        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
+    }
+
+    // Test 4: Verify rejection metrics were recorded
+    let rejection_count = metrics.get_rate_limit_requests_counter(
+        "slow_ttft",
+        &Endpoint::ChatCompletions,
+        &RequestType::Stream,
+        &Status::Rejected,
+    ) + metrics.get_rate_limit_requests_counter(
+        "slow_itl",
+        &Endpoint::ChatCompletions,
+        &RequestType::Stream,
+        &Status::Rejected,
+    );
+
+    println!("Total rejection count: {}", rejection_count);
+
+    cancel_token.cancel();
+    task.await.unwrap().unwrap();
+}
+
+#[tokio::test]
+async fn test_rate_limiting_http_integration() {
+    let rate_limiter_config = RateLimiterConfig::new(0.1, 0.01, 5.0, false).unwrap();
+    let service = HttpService::builder()
+        .port(8991)
+        .with_rate_limiter(rate_limiter_config)
+        .build()
+        .unwrap();
+
+    let state = service.state_clone();
+    let manager = state.manager();
+
+    let token = CancellationToken::new();
+    let cancel_token = token.clone();
+    let task = tokio::spawn(async move { service.run(token.clone()).await });
+
+    // Use simple CounterEngine (already exists)
+    let engine = Arc::new(CounterEngine {});
+    manager.add_chat_completions_model("test", engine).unwrap();
+
+    // Manually record high TTFT values to trigger rate limiting
+    state.rate_limiter().record_ttft("test", 0.5);
+    state.rate_limiter().record_ttft("test", 0.3);
+    state.rate_limiter().record_ttft("test", 0.4);
+
+    let client = reqwest::Client::new();
+    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
+
+    let request = async_openai::types::CreateChatCompletionRequestArgs::default()
+        .model("test")
+        .messages(vec![
+            async_openai::types::ChatCompletionRequestMessage::User(
+                async_openai::types::ChatCompletionRequestUserMessage {
+                    content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
+                        "test".to_string(),
+                    ),
+                    name: None,
+                },
+            ),
+        ])
+        .stream(true)
+        .max_tokens(3 as u32)
+        .build()
+        .unwrap();
+
+    // This request should be rate limited
+    let response = client
+        .post("http://localhost:8991/v1/chat/completions")
+        .json(&request)
+        .send()
+        .await
+        .unwrap();
+
+    assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
+    println!("✅ Rate limiting triggered correctly!");
+
+    // Verify metrics were recorded
+    let rejection_count = state.metrics_clone().get_rate_limit_requests_counter(
+        "test",
+        &Endpoint::ChatCompletions,
+        &RequestType::Stream,
+        &Status::Rejected,
+    );
+    assert!(rejection_count > 0, "Should have recorded rejection");
+
+    cancel_token.cancel();
+    task.await.unwrap().unwrap();
+}
+
+#[tokio::test]
+async fn test_per_model_vs_global_rate_limiting() {
+    // Test global rate limiting (per_model_limits = false)
+    let global_config = RateLimiterConfig::new(0.8, 0.08, 3.0, false).unwrap();
+    let service1 = HttpService::builder()
+        .port(8992)
+        .with_rate_limiter(global_config)
+        .build()
+        .unwrap();
+
+    let state1 = service1.state_clone();
+    let manager1 = state1.manager();
+
+    let token1 = CancellationToken::new();
+    let cancel_token1 = token1.clone();
+    let task1 = tokio::spawn(async move { service1.run(token1.clone()).await });
+
+    // Test per-model rate limiting (per_model_limits = true)
+    let per_model_config = RateLimiterConfig::new(0.8, 0.08, 3.0, true).unwrap();
+    let service2 = HttpService::builder()
+        .port(8993)
+        .with_rate_limiter(per_model_config)
+        .build()
+        .unwrap();
+
+    let state2 = service2.state_clone();
+    let manager2 = state2.manager();
+
+    let token2 = CancellationToken::new();
+    let cancel_token2 = token2.clone();
+    let task2 = tokio::spawn(async move { service2.run(token2.clone()).await });
+
+    // Add slow engines to both services
+    let slow_engine1a = Arc::new(SlowTTFTEngine {
+        ttft_delay_ms: 1200,
+    });
+    let slow_engine1b = Arc::new(SlowTTFTEngine {
+        ttft_delay_ms: 1200,
+    });
+    let slow_engine2a = Arc::new(SlowTTFTEngine {
+        ttft_delay_ms: 1200,
+    });
+    let slow_engine2b = Arc::new(SlowTTFTEngine {
+        ttft_delay_ms: 1200,
+    });
+
+    manager1
+        .add_chat_completions_model("model_a", slow_engine1a)
+        .unwrap();
+    manager1
+        .add_chat_completions_model("model_b", slow_engine1b)
+        .unwrap();
+    manager2
+        .add_chat_completions_model("model_a", slow_engine2a)
+        .unwrap();
+    manager2
+        .add_chat_completions_model("model_b", slow_engine2b)
+        .unwrap();
+
+    let client = reqwest::Client::new();
+
+    tokio::time::sleep(std::time::Duration::from_millis(200)).await;
+
+    let base_request = async_openai::types::CreateChatCompletionRequestArgs::default()
+        .messages(vec![
+            async_openai::types::ChatCompletionRequestMessage::User(
+                async_openai::types::ChatCompletionRequestUserMessage {
+                    content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
+                        "test".to_string(),
+                    ),
+                    name: None,
+                },
+            ),
+        ])
+        .stream(true)
+        .max_tokens(10 as u32)
+        .build()
+        .unwrap();
+
+    // Test global rate limiting - model_a affects model_b
+    println!("Testing global rate limiting...");
+    for _i in 0..3 {
+        let mut request_a = base_request.clone();
+        request_a.model = "model_a".to_string();
+
+        let response = client
+            .post("http://localhost:8992/v1/chat/completions")
+            .json(&request_a)
+            .send()
+            .await
+            .unwrap();
+
+        if response.status().is_success() {
+            let _ = response.bytes().await.unwrap();
+        }
+        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
+    }
+
+    // Now model_b should be affected by model_a's rate limiting (global)
+    let mut request_b = base_request.clone();
+    request_b.model = "model_b".to_string();
+
+    let response = client
+        .post("http://localhost:8992/v1/chat/completions")
+        .json(&request_b)
+        .send()
+        .await
+        .unwrap();
+
+    println!(
+        "Global rate limiting - model_b status: {}",
+        response.status()
+    );
+
+    // Test per-model rate limiting - model_a doesn't affect model_b
+    println!("Testing per-model rate limiting...");
+    for _i in 0..3 {
+        let mut request_a = base_request.clone();
+        request_a.model = "model_a".to_string();
+
+        let response = client
+            .post("http://localhost:8993/v1/chat/completions")
+            .json(&request_a)
+            .send()
+            .await
+            .unwrap();
+
+        if response.status().is_success() {
+            let _ = response.bytes().await.unwrap();
+        }
+        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
+    }
+
+    // model_b should NOT be affected by model_a's rate limiting (per-model)
+    let mut request_b2 = base_request.clone();
+    request_b2.model = "model_b".to_string();
+
+    let response = client
+        .post("http://localhost:8993/v1/chat/completions")
+        .json(&request_b2)
+        .send()
+        .await
+        .unwrap();
+
+    println!(
+        "Per-model rate limiting - model_b status: {}",
+        response.status()
+    );
+    // Model B should not be rate limited, since it has its own rate limiting state
+    assert_ne!(
+        response.status(),
+        StatusCode::TOO_MANY_REQUESTS,
+        "Per-model rate limiting should not affect model_b"
+    );
+
+    if response.status().is_success() {
+        let _ = response.bytes().await.unwrap();
+    }
+
+    cancel_token1.cancel();
+    cancel_token2.cancel();
+    task1.await.unwrap().unwrap();
+    task2.await.unwrap().unwrap();
+}
+
+#[tokio::test]
+async fn test_rate_limiting_recovery() {
+    let rate_limiter_config = RateLimiterConfig::new(
+        0.6, // TTFT threshold: 600ms
+        0.06, // ITL threshold: 60ms
+        1.0, // Short time constant for faster recovery
+        false,
+    )
+    .unwrap();
+
+    let service = HttpService::builder()
+        .port(8994)
+        .with_rate_limiter(rate_limiter_config)
+        .build()
+        .unwrap();
+
+    let state = service.state_clone();
+    let manager = state.manager();
+
+    let token = CancellationToken::new();
+    let cancel_token = token.clone();
+    let task = tokio::spawn(async move { service.run(token.clone()).await });
+
+    // Add engines with different speeds
+    let slow_engine = Arc::new(SlowTTFTEngine {
+        ttft_delay_ms: 1000,
+    }); // 1s TTFT
+    let fast_engine = Arc::new(CounterEngine {});
+
+    manager
+        .add_chat_completions_model("slow", slow_engine)
+        .unwrap();
+    manager
+        .add_chat_completions_model("fast", fast_engine)
+        .unwrap();
+
+    let client = reqwest::Client::new();
+
+    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
+
+    let slow_request = async_openai::types::CreateChatCompletionRequestArgs::default()
+        .model("slow")
+        .messages(vec![
+            async_openai::types::ChatCompletionRequestMessage::User(
+                async_openai::types::ChatCompletionRequestUserMessage {
+                    content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
+                        "test".to_string(),
+                    ),
+                    name: None,
+                },
+            ),
+        ])
+        .stream(true)
+        .max_tokens(10 as u32)
+        .build()
+        .unwrap();
+
+    let fast_request = async_openai::types::CreateChatCompletionRequestArgs::default()
+        .model("fast")
+        .messages(vec![
+            async_openai::types::ChatCompletionRequestMessage::User(
+                async_openai::types::ChatCompletionRequestUserMessage {
+                    content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
+                        "test".to_string(),
+                    ),
+                    name: None,
+                },
+            ),
+        ])
+        .stream(true)
+        .max_tokens(10 as u32)
+        .build()
+        .unwrap();
+
+    // Phase 1: Trigger rate limiting with slow requests
+    println!("Phase 1: Triggering rate limiting...");
+    for i in 0..4 {
+        let response = client
+            .post("http://localhost:8994/v1/chat/completions")
+            .json(&slow_request)
+            .send()
+            .await
+            .unwrap();
+
+        if response.status() == StatusCode::TOO_MANY_REQUESTS {
+            println!("Rate limiting triggered at request {}", i + 1);
+            break;
+        } else {
+            let _ = response.bytes().await.unwrap();
+        }
+        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
+    }
+
+    // Phase 2: Wait for system to recover (time constant is 1.0s)
+    println!("Phase 2: Waiting for recovery...");
+    tokio::time::sleep(std::time::Duration::from_millis(3000)).await; // Wait 3 time constants
+
+    // Phase 3: Send fast requests to bring down the EMA
+    println!("Phase 3: Sending fast requests to improve EMA...");
+    for i in 0..3 {
+        let response = client
+            .post("http://localhost:8994/v1/chat/completions")
+            .json(&fast_request)
+            .send()
+            .await
+            .unwrap();
+
+        println!("Fast request {} status: {}", i + 1, response.status());
+        if response.status().is_success() {
+            let _ = response.bytes().await.unwrap();
+        }
+        tokio::time::sleep(std::time::Duration::from_millis(200)).await;
+    }
+
+    // Phase 4: Verify that requests are accepted again (system recovered)
+    println!("Phase 4: Testing recovery with a fast request...");
+    let response = client
+        .post("http://localhost:8994/v1/chat/completions")
+        .json(&fast_request)
+        .send()
+        .await
+        .unwrap();
+
+    println!("Recovery test status: {}", response.status());
+    assert!(
+        response.status().is_success(),
+        "System should have recovered and accept requests again"
+    );
+
+    if response.status().is_success() {
+        let _ = response.bytes().await.unwrap();
+    }
+
+    cancel_token.cancel();
+    task.await.unwrap().unwrap();
+}

From 64033f57d4c9c6353dac707f210fca020ccf6dfb Mon Sep 17 00:00:00 2001
From: Jorge Antonio
Date: Sun, 20 Jul 2025 12:20:38 +0100
Subject: [PATCH 08/11] clippy checks

---
 lib/llm/tests/http-service.rs | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/lib/llm/tests/http-service.rs b/lib/llm/tests/http-service.rs
index 241d7349ae..b3a348ae7e 100644
--- a/lib/llm/tests/http-service.rs
+++ b/lib/llm/tests/http-service.rs
@@ -1411,7 +1411,7 @@ async fn test_rate_limiting_triggers_correctly() {
             ),
         ])
         .stream(true)
-        .max_tokens(10 as u32)
+        .max_tokens(10_u32)
         .build()
         .unwrap();

@@ -1659,7 +1659,7 @@ async fn test_per_model_vs_global_rate_limiting() {
             ),
         ])
         .stream(true)
-        .max_tokens(10 as u32)
+        .max_tokens(10_u32)
         .build()
         .unwrap();

@@ -1801,7 +1801,7 @@ async fn test_rate_limiting_recovery() {
             ),
         ])
         .stream(true)
-        .max_tokens(10 as u32)
+        .max_tokens(10_u32)
         .build()
         .unwrap();

@@ -1818,7 +1818,7 @@ async fn test_rate_limiting_recovery() {
             ),
         ])
         .stream(true)
-        .max_tokens(10 as u32)
+        .max_tokens(10_u32)
         .build()
         .unwrap();

From 151e702bea45b161bdd6158790cdcb0a17165f1a Mon Sep 17 00:00:00 2001
From: Jorge Antonio
Date: Sun, 20 Jul 2025 13:00:40 +0100
Subject: [PATCH 09/11] clippy checks

---
 lib/llm/tests/http-service.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/llm/tests/http-service.rs b/lib/llm/tests/http-service.rs
index b3a348ae7e..d475ad4217 100644
--- a/lib/llm/tests/http-service.rs
+++ b/lib/llm/tests/http-service.rs
@@ -1556,7 +1556,7 @@ async fn test_rate_limiting_http_integration() {
             ),
         ])
         .stream(true)
-        .max_tokens(3 as u32)
+        .max_tokens(3_u32)
         .build()
         .unwrap();

From 32848d6556ef538dbd1ab602a1f5f9d12b6b8dd6 Mon Sep 17 00:00:00 2001
From: Jorge Antonio
Date: Sun, 20 Jul 2025 13:20:37 +0100
Subject: [PATCH 10/11] clippy checks

---
 lib/llm/src/http/service/rate_limiter.rs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/lib/llm/src/http/service/rate_limiter.rs b/lib/llm/src/http/service/rate_limiter.rs
index 3666578eb5..be3be6703b 100644
--- a/lib/llm/src/http/service/rate_limiter.rs
+++ b/lib/llm/src/http/service/rate_limiter.rs
@@ -544,8 +544,8 @@ mod tests {
         let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS);

         // Record samples with known values and controlled timing
-        let sample_values = vec![100.0, 200.0, 300.0, 400.0];
-        let sample_delays_ms = vec![0, 500, 1000, 1500]; // Delays in milliseconds
+        let sample_values = [100.0, 200.0, 300.0, 400.0];
+        let sample_delays_ms = [0, 500, 1000, 1500]; // Delays in milliseconds

         let start_time = Instant::now();

@@ -1155,7 +1155,7 @@ mod tests {

         assert!(avg.is_finite(), "Rapid updates should maintain stability");
         assert!(
-            avg >= 0.0 && avg <= 9.0,
+            (0.0..=9.0).contains(&avg),
             "Average should be bounded: {}",
             avg
         );

From bec479c2b6d4eb036324b241ea074f64ed5b517a Mon Sep 17 00:00:00 2001
From: Jorge Antonio
Date: Mon, 21 Jul 2025 09:48:49 -0400
Subject: [PATCH 11/11] update tests

---
 lib/llm/src/http/service/rate_limiter.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/llm/src/http/service/rate_limiter.rs b/lib/llm/src/http/service/rate_limiter.rs
index be3be6703b..17cb455ed8 100644
--- a/lib/llm/src/http/service/rate_limiter.rs
+++ b/lib/llm/src/http/service/rate_limiter.rs
@@ -682,7 +682,7 @@ mod tests {
         tracker.record_value(42.0);
         let single_avg = tracker.get_decayed_time_weighted_average();
         assert!(
-            (single_avg - 42.0).abs() < 1e-6,
+            (single_avg - 42.0).abs() < 1e-5,
             "Single sample average should equal sample value: {}",
             single_avg
         );