diff --git a/Cargo.lock b/Cargo.lock index de7d891621..f38923530a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1501,6 +1501,20 @@ dependencies = [ "parking_lot_core", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if 1.0.0", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.9.0" @@ -1843,6 +1857,7 @@ dependencies = [ "chrono", "criterion", "cudarc 0.16.2", + "dashmap 6.1.0", "derive-getters", "derive_builder", "dialoguer", @@ -8494,7 +8509,7 @@ dependencies = [ "asynchronous-codec", "bytes", "crossbeam-queue", - "dashmap", + "dashmap 5.5.3", "futures-channel", "futures-io", "futures-task", diff --git a/docs/guides/rate_limiting.md b/docs/guides/rate_limiting.md new file mode 100644 index 0000000000..3f36af3494 --- /dev/null +++ b/docs/guides/rate_limiting.md @@ -0,0 +1,165 @@ +# Rate Limiting Guide + +## Overview + +The Dynamo LLM service includes an intelligent rate limiter that monitors service performance metrics and automatically throttles requests when quality degrades. 
Unlike traditional rate limiters that count requests, this system focuses on maintaining good user experience by monitoring: + +- **Time to First Token (TTFT)** - How long users wait for the first response +- **Inter-Token Latency (ITL)** - How long between subsequent tokens + +## How It Works + +### Time-Weighted Exponential Moving Average + +The rate limiter uses a sophisticated time-weighted exponential moving average (EMA) algorithm: + +```text +average = sum(value * weight) / sum(weight) +weight = exp(-age / time_constant_secs) +``` + + +This means: +- Recent samples have higher influence on the average +- Old samples decay exponentially over time +- System "recovers" during idle periods + +### Decision Logic + +For each incoming request, the system: +1. Computes current decayed EMA for TTFT and ITL +2. Compares against configured thresholds +3. Rejects request if either threshold is exceeded +4. Logs detailed metrics for observability + +## Configuration + +### Environment Variables + +```bash +# Enable rate limiting +export DYN_RATE_LIMITER_ENABLED=true + +# TTFT threshold in milliseconds (default: 1000ms = 1s) +export DYN_RATE_LIMITER_TTFT_THRESHOLD_MS=1500 + +# ITL threshold in milliseconds (default: 10ms) +export DYN_RATE_LIMITER_ITL_THRESHOLD_MS=15 + +# Time constant for EMA decay (default: 30s) +export DYN_RATE_LIMITER_TIME_CONSTANT_SECS=60 + +# Enable per-model vs global limits (default: false) +export DYN_RATE_LIMITER_PER_MODEL_LIMITS=true +``` + +### Command Line Arguments + +```bash +dynamo-http \ + --enable-rate-limiting \ + --ttft-threshold-ms 1500 \ + --itl-threshold-ms 15 \ + --time-constant-secs 60 \ + --per-model-limits +``` + +### Programmatic Configuration + +```rust +use dynamo_llm::http::service::rate_limiter::RateLimiterConfig; + +let config = RateLimiterConfig::new( + 1500.0, // TTFT threshold (ms) + 15.0, // ITL threshold (ms) + 60.0, // Time constant (s) + true, // Per-model limits +); + +let http_service = HttpService::builder() + 
+    .rate_limiter_config(config)
**Monitor Closely**: Watch both rate limit counters and user-facing metrics +3. **Load Test**: Validate behavior under realistic load patterns +4. **Document SLAs**: Clearly communicate expected performance to users +5. **Alert on Rejections**: Set up alerts when rejection rates exceed acceptable levels + +## Troubleshooting + +### High Rejection Rates +- Check if system is genuinely overloaded +- Consider increasing thresholds temporarily +- Scale backend resources +- Investigate specific models causing issues + +### No Rejections During Overload +- Verify rate limiter is enabled +- Check threshold configuration +- Ensure metrics are being recorded properly +- Review time constant settings + +### Inconsistent Behavior +- Check if per-model limits are configured correctly +- Review metric collection for gaps +- Validate system clock stability \ No newline at end of file diff --git a/launch/dynamo-run/src/flags.rs b/launch/dynamo-run/src/flags.rs index 4bb086410a..a51b42a56a 100644 --- a/launch/dynamo-run/src/flags.rs +++ b/launch/dynamo-run/src/flags.rs @@ -18,6 +18,7 @@ use std::path::PathBuf; use clap::ValueEnum; use dynamo_llm::entrypoint::RouterConfig; +use dynamo_llm::http::service::rate_limiter::RateLimiterConfig; use dynamo_llm::kv_router::KvRouterConfig; use dynamo_llm::local_model::LocalModel; use dynamo_llm::mocker::protocols::MockEngineArgs; @@ -171,6 +172,26 @@ pub struct Flags { /// These are the command line arguments to the python engine when using `pystr` or `pytok`. #[arg(index = 2, last = true, hide = true, allow_hyphen_values = true)] pub last: Vec, + + /// Enables rate limiter config. + #[arg(long)] + pub enable_rate_limiter: Option, + + /// Time to first token threshold in seconds, for the OpenAI HTTP service rate limiter. + #[arg(long)] + pub rate_limiter_ttft_threshold_secs: Option, + + /// Inter-token latency threshold in seconds, for the OpenAI HTTP service rate limiter. 
+ #[arg(long)] + pub rate_limiter_itl_threshold_secs: Option, + + /// Time constant for the time-weighted EMA, for the OpenAI HTTP service rate limiter. + #[arg(long)] + pub rate_limiter_time_constant_secs: Option, + + /// Whether to use per-model limits, for the OpenAI HTTP service rate limiter. + #[arg(long)] + pub rate_limiter_per_model_limits: Option, } impl Flags { @@ -240,6 +261,28 @@ impl Flags { ) } + pub fn rate_limiter_config(&self) -> RateLimiterConfig { + if self.enable_rate_limiter.is_none() { + return RateLimiterConfig::empty(); + } + + let mut builder = RateLimiterConfig::builder(); + if let Some(ttft_threshold_secs) = self.rate_limiter_ttft_threshold_secs { + builder = builder.ttft_threshold_secs(ttft_threshold_secs); + } + if let Some(itl_threshold_secs) = self.rate_limiter_itl_threshold_secs { + builder = builder.itl_threshold_secs(itl_threshold_secs); + } + if let Some(time_constant_secs) = self.rate_limiter_time_constant_secs { + builder = builder.time_constant_secs(time_constant_secs); + } + if let Some(per_model_limits) = self.rate_limiter_per_model_limits { + builder = builder.per_model_limits(per_model_limits); + } + + builder.build().unwrap_or_default() + } + /// Load extra engine arguments from a JSON file /// Returns a HashMap of parameter names to values pub fn load_extra_engine_args( diff --git a/launch/dynamo-run/src/lib.rs b/launch/dynamo-run/src/lib.rs index 5662db762c..5c888e9223 100644 --- a/launch/dynamo-run/src/lib.rs +++ b/launch/dynamo-run/src/lib.rs @@ -46,7 +46,8 @@ pub async fn run( .http_port(Some(flags.http_port)) .router_config(Some(flags.router_config())) .request_template(flags.request_template.clone()) - .migration_limit(flags.migration_limit); + .migration_limit(flags.migration_limit) + .rate_limiter_config(flags.rate_limiter_config()); // If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint. // If not, then the endpoint isn't exposed so we let LocalModel invent one. 
diff --git a/lib/bindings/python/Cargo.lock b/lib/bindings/python/Cargo.lock index c7b5697db3..fbb39c2d13 100644 --- a/lib/bindings/python/Cargo.lock +++ b/lib/bindings/python/Cargo.lock @@ -931,6 +931,20 @@ dependencies = [ "parking_lot_core", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if 1.0.0", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.9.0" @@ -1139,6 +1153,7 @@ dependencies = [ "candle-core", "chrono", "cudarc", + "dashmap 6.1.0", "derive-getters", "derive_builder", "dialoguer", @@ -5794,7 +5809,7 @@ dependencies = [ "asynchronous-codec", "bytes", "crossbeam-queue", - "dashmap", + "dashmap 5.5.3", "futures-channel", "futures-io", "futures-task", diff --git a/lib/bindings/python/rust/http.rs b/lib/bindings/python/rust/http.rs index 3a22092334..7a178d5db3 100644 --- a/lib/bindings/python/rust/http.rs +++ b/lib/bindings/python/rust/http.rs @@ -36,9 +36,17 @@ pub struct HttpService { #[pymethods] impl HttpService { #[new] - #[pyo3(signature = (port=None))] - pub fn new(port: Option) -> PyResult { - let builder = service_v2::HttpService::builder().port(port.unwrap_or(8080)); + #[pyo3(signature = (port=None, rate_limiter_config=None))] + pub fn new( + port: Option, + rate_limiter_config: Option, + ) -> PyResult { + let mut builder = service_v2::HttpService::builder().port(port.unwrap_or(8080)); + + if let Some(rate_limiter_config) = rate_limiter_config { + builder = builder.rate_limiter_config(rate_limiter_config.inner); + } + let inner = builder.build().map_err(to_pyerr)?; Ok(Self { inner }) } @@ -184,3 +192,30 @@ where } } } + +#[pyclass] +#[derive(Clone)] +pub struct RateLimiterConfig { + inner: dynamo_llm::http::service::rate_limiter::RateLimiterConfig, +} + +#[pymethods] +impl 
RateLimiterConfig { + #[new] + pub fn new( + ttft_threshold_secs: f64, + itl_threshold_secs: f64, + time_constant_secs: f64, + per_model_limits: bool, + ) -> PyResult { + let inner = dynamo_llm::http::service::rate_limiter::RateLimiterConfig::new( + ttft_threshold_secs, + itl_threshold_secs, + time_constant_secs, + per_model_limits, + ) + .map_err(to_pyerr)?; + + Ok(Self { inner }) + } +} diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index 5b548352f8..168bbdb69a 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -100,6 +100,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/lib/bindings/python/src/dynamo/_core.pyi b/lib/bindings/python/src/dynamo/_core.pyi index a32aaf4d84..09a3435a94 100644 --- a/lib/bindings/python/src/dynamo/_core.pyi +++ b/lib/bindings/python/src/dynamo/_core.pyi @@ -795,7 +795,11 @@ class HttpService: It is a OpenAI compatible http ingress into the Dynamo Distributed Runtime. """ - ... + def __init__( + self, + port: Optional[int] = None, + rate_limiter_config: Optional[RateLimiterConfig] = None + ) -> None: ... class HttpError: """ @@ -804,6 +808,19 @@ class HttpError: ... +class RateLimiterConfig: + """ + A configuration for the HTTP service rate limiter logic + """ + + def __init__( + self, + ttft_threshold_secs: Optional[float] = None, + itl_threshold_secs: Optional[float] = None, + time_constant_secs: Optional[float] = None, + per_model_limits: Optional[bool] = None + ) -> None: ... + class HttpAsyncEngine: """ An async engine for a distributed Dynamo http service. 
This is an extension of the diff --git a/lib/llm/Cargo.toml b/lib/llm/Cargo.toml index 355b8f4116..30c0c1849d 100644 --- a/lib/llm/Cargo.toml +++ b/lib/llm/Cargo.toml @@ -36,6 +36,10 @@ testing-nixl = ["dep:nixl-sys"] block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:ndarray", "dep:nix"] sentencepiece = ["dep:sentencepiece"] +[[bench]] +name = "rate_limiter" +harness = false + [[bench]] name = "tokenizer" harness = false @@ -109,6 +113,7 @@ tokenizers = { version = "0.21.1", default-features = false, features = [ sentencepiece = { version = "0.11.2", optional = true } # backend +dashmap = { version = "6.1.0" } galil-seiferas = { version = "0.1" } toktrie = { version = "0.6.28" } toktrie_hf_tokenizers = { version = "0.6.28" } diff --git a/lib/llm/benches/rate_limiter.rs b/lib/llm/benches/rate_limiter.rs new file mode 100644 index 0000000000..741a623c62 --- /dev/null +++ b/lib/llm/benches/rate_limiter.rs @@ -0,0 +1,336 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use dynamo_llm::http::service::rate_limiter::{ + RateLimiter, RateLimiterConfig, TimeWeightedAverageTracker, +}; +use std::sync::Arc; +use std::thread; + +// Benchmark configurations +const SAMPLE_SIZES: &[usize] = &[10, 100, 1000, 10000]; +const TIME_CONSTANTS: &[f64] = &[1.0, 10.0, 30.0, 60.0]; +const THREAD_COUNTS: &[usize] = &[1, 2, 4, 8, 16]; + +/// Benchmark recording single values to the tracker +fn bench_record_value(c: &mut Criterion) { + let mut group = c.benchmark_group("record_value"); + + for &sample_size in SAMPLE_SIZES { + group.throughput(Throughput::Elements(sample_size as u64)); + + group.bench_with_input( + BenchmarkId::new("sequential", sample_size), + &sample_size, + |b, &size| { + b.iter(|| { + let mut tracker = TimeWeightedAverageTracker::new(10.0); + for i in 0..size { + tracker.record_value(black_box(i as f64)); + } + }); + }, + ); + } + group.finish(); +} + +/// Benchmark computing time-weighted averages +fn bench_time_weighted_average(c: &mut Criterion) { + let mut group = c.benchmark_group("time_weighted_average"); + + for &sample_size in SAMPLE_SIZES { + group.throughput(Throughput::Elements(1)); // One calculation per iteration + + group.bench_with_input( + BenchmarkId::new("computation", sample_size), + &sample_size, + |b, &size| { + // Pre-populate tracker with samples + let mut tracker = TimeWeightedAverageTracker::new(10.0); + for i in 0..size { + tracker.record_value(i as f64); + } + + b.iter(|| { + black_box(tracker.get_decayed_time_weighted_average()); + }); + }, + ); + } + group.finish(); +} + +/// Benchmark different time constants impact on performance +fn bench_time_constants(c: &mut Criterion) { + let mut group = c.benchmark_group("time_constants"); + + const SAMPLE_SIZE: usize = 1000; + + for &time_constant in TIME_CONSTANTS { + group.bench_with_input( + 
BenchmarkId::new("record_and_compute", time_constant), + &time_constant, + |b, &tc| { + b.iter(|| { + let mut tracker = TimeWeightedAverageTracker::new(tc); + + // Record samples + for i in 0..SAMPLE_SIZE { + tracker.record_value(black_box(i as f64)); + } + + // Compute average + black_box(tracker.get_decayed_time_weighted_average()); + }); + }, + ); + } + group.finish(); +} + +/// Benchmark rate limiter decision making +fn bench_rate_limiter_decisions(c: &mut Criterion) { + let mut group = c.benchmark_group("rate_limiter_decisions"); + + let config = RateLimiterConfig::new(100.0, 10.0, 10.0, false).unwrap(); + + group.bench_function("should_reject_with_data", |b| { + let rate_limiter = RateLimiter::new(config.clone()); + + // Pre-populate with samples + for i in 0..100 { + rate_limiter.record_ttft("test-model", 50.0 + i as f64); + rate_limiter.record_itl("test-model", 5.0 + (i as f64 / 10.0)); + } + + b.iter(|| { + black_box(rate_limiter.should_reject(black_box("test-model"))); + }); + }); + + group.bench_function("record_ttft", |b| { + let rate_limiter = RateLimiter::new(config.clone()); + let mut counter = 0; + + b.iter(|| { + rate_limiter.record_ttft(black_box("test-model"), black_box(counter as f64)); + counter += 1; + }); + }); + + group.bench_function("record_itl", |b| { + let rate_limiter = RateLimiter::new(config.clone()); + let mut counter = 0; + + b.iter(|| { + rate_limiter.record_itl(black_box("test-model"), black_box(counter as f64)); + counter += 1; + }); + }); + + group.finish(); +} + +/// Benchmark concurrent access patterns +fn bench_concurrent_access(c: &mut Criterion) { + let mut group = c.benchmark_group("concurrent_access"); + + for &thread_count in THREAD_COUNTS { + group.throughput(Throughput::Elements(thread_count as u64 * 100)); + + group.bench_with_input( + BenchmarkId::new("multi_thread_records", thread_count), + &thread_count, + |b, &num_threads| { + b.iter(|| { + let config = RateLimiterConfig::new(1000.0, 10.0, 30.0, false).unwrap(); + 
let rate_limiter = Arc::new(RateLimiter::new(config)); + + let handles: Vec<_> = (0..num_threads) + .map(|thread_id| { + let limiter = rate_limiter.clone(); + thread::spawn(move || { + for i in 0..100 { + let value = (thread_id * 100 + i) as f64; + limiter.record_ttft("test-model", value); + limiter.record_itl("test-model", value / 10.0); + + // Some threads check rejection status + if i % 10 == 0 { + black_box(limiter.should_reject("test-model")); + } + } + }) + }) + .collect(); + + for handle in handles { + handle.join().unwrap(); + } + + black_box(rate_limiter); + }); + }, + ); + } + group.finish(); +} + +/// Benchmark memory allocation patterns +fn bench_memory_patterns(c: &mut Criterion) { + let mut group = c.benchmark_group("memory_patterns"); + + group.bench_function("memory_bounded_growth", |b| { + b.iter(|| { + let mut tracker = TimeWeightedAverageTracker::new(10.0); + + // Add way more samples than max_samples to test memory bounds + for i in 0..1000 { + tracker.record_value(black_box(i as f64)); + + // Occasionally compute average to trigger cleanup + if i % 50 == 0 { + black_box(tracker.get_decayed_time_weighted_average()); + } + } + + black_box(tracker); + }); + }); + + group.bench_function("per_model_isolation", |b| { + let config = RateLimiterConfig::new(1000.0, 10.0, 30.0, true).unwrap(); + + b.iter(|| { + let rate_limiter = RateLimiter::new(config.clone()); + + // Simulate multiple models + for model_id in 0..10 { + let model_name = format!("model-{}", model_id); + for i in 0..50 { + rate_limiter.record_ttft(&model_name, i as f64); + rate_limiter.record_itl(&model_name, (i as f64) / 10.0); + } + black_box(rate_limiter.should_reject(&model_name)); + } + + black_box(rate_limiter); + }); + }); + + group.finish(); +} + +/// Benchmark edge cases and stress scenarios +fn bench_edge_cases(c: &mut Criterion) { + let mut group = c.benchmark_group("edge_cases"); + + group.bench_function("rapid_fire_records", |b| { + b.iter(|| { + let mut tracker = 
TimeWeightedAverageTracker::new(1.0); + + // Rapid fire recording without any delays + for i in 0..5000 { + tracker.record_value(black_box(i as f64)); + } + + black_box(tracker.get_decayed_time_weighted_average()); + }); + }); + + group.bench_function("alternating_high_low_values", |b| { + b.iter(|| { + let mut tracker = TimeWeightedAverageTracker::new(5.0); + + // Alternating between very high and very low values + for i in 0..500 { + let value = if i % 2 == 0 { 1000000.0 } else { 0.001 }; + tracker.record_value(black_box(value)); + } + + black_box(tracker.get_decayed_time_weighted_average()); + }); + }); + + group.bench_function("very_old_samples", |b| { + b.iter(|| { + let mut tracker = TimeWeightedAverageTracker::new(0.0001); // Very short time constant + + // Add some samples + for i in 0..100 { + tracker.record_value(black_box(i as f64)); + } + + // Add fresh samples + for i in 100..200 { + tracker.record_value(black_box(i as f64)); + } + + black_box(tracker.get_decayed_time_weighted_average()); + }); + }); + + group.finish(); +} + +/// Comprehensive benchmark comparing different configurations +fn bench_configuration_comparison(c: &mut Criterion) { + let mut group = c.benchmark_group("configuration_comparison"); + + let configs = vec![ + ( + "aggressive", + RateLimiterConfig::new(1000.0, 10.0, 1.0, false).unwrap(), + ), + ( + "balanced", + RateLimiterConfig::new(1000.0, 10.0, 10.0, false).unwrap(), + ), + ( + "conservative", + RateLimiterConfig::new(1000.0, 10.0, 60.0, false).unwrap(), + ), + ]; + + for (name, config) in configs { + group.bench_with_input( + BenchmarkId::new("full_workflow", name), + &config, + |b, config| { + b.iter(|| { + let rate_limiter = RateLimiter::new(config.clone()); + + // Simulate realistic usage pattern + for i in 0..200 { + rate_limiter.record_ttft("model", black_box(50.0 + (i as f64))); + rate_limiter.record_itl("model", black_box(5.0 + (i as f64 / 10.0))); + + if i % 20 == 0 { + black_box(rate_limiter.should_reject("model")); 
+ } + } + + black_box(rate_limiter); + }); + }, + ); + } + + group.finish(); +} + +criterion_group!( + benches, + bench_record_value, + bench_time_weighted_average, + bench_time_constants, + bench_rate_limiter_decisions, + bench_concurrent_access, + bench_memory_patterns, + bench_edge_cases, + bench_configuration_comparison +); + +criterion_main!(benches); diff --git a/lib/llm/src/entrypoint/input/http.rs b/lib/llm/src/entrypoint/input/http.rs index 01bfa61c69..ddabf783e9 100644 --- a/lib/llm/src/entrypoint/input/http.rs +++ b/lib/llm/src/entrypoint/input/http.rs @@ -27,6 +27,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul .enable_chat_endpoints(true) .enable_cmpl_endpoints(true) .enable_embeddings_endpoints(true) + .rate_limiter_config(engine_config.local_model().rate_limiter_config()) .with_request_template(engine_config.local_model().request_template()) .build()?; match engine_config { diff --git a/lib/llm/src/http/service.rs b/lib/llm/src/http/service.rs index 7f163a200f..44b86d0ab6 100644 --- a/lib/llm/src/http/service.rs +++ b/lib/llm/src/http/service.rs @@ -24,6 +24,7 @@ pub mod disconnect; pub mod error; pub mod health; pub mod metrics; +pub mod rate_limiter; pub mod service_v2; pub use axum; diff --git a/lib/llm/src/http/service/metrics.rs b/lib/llm/src/http/service/metrics.rs index 181763895c..a0ce4654a4 100644 --- a/lib/llm/src/http/service/metrics.rs +++ b/lib/llm/src/http/service/metrics.rs @@ -10,6 +10,8 @@ use std::{ pub use prometheus::Registry; +use crate::http::service::rate_limiter::RateLimiter; + use super::RouteDoc; /// Value for the `status` label in the request counter for successful requests @@ -18,6 +20,9 @@ pub const REQUEST_STATUS_SUCCESS: &str = "success"; /// Value for the `status` label in the request counter if the request failed pub const REQUEST_STATUS_ERROR: &str = "error"; +/// Value for the `status` label in the request counter if the request was rejected by the rate limiter +pub const 
REQUEST_STATUS_REJECTED: &str = "rejected"; + /// Partial value for the `type` label in the request counter for streaming requests pub const REQUEST_TYPE_STREAM: &str = "stream"; @@ -25,6 +30,7 @@ pub const REQUEST_TYPE_STREAM: &str = "stream"; pub const REQUEST_TYPE_UNARY: &str = "unary"; pub struct Metrics { + rate_limit_requests_counter: IntCounterVec, request_counter: IntCounterVec, inflight_gauge: IntGaugeVec, request_duration: HistogramVec, @@ -32,6 +38,8 @@ pub struct Metrics { output_sequence_length: HistogramVec, time_to_first_token: HistogramVec, inter_token_latency: HistogramVec, + rate_limit_ema_ttft: HistogramVec, + rate_limit_ema_itl: HistogramVec, } /// RAII object for inflight gauge and request counters @@ -72,15 +80,27 @@ pub enum RequestType { Stream, } +impl RequestType { + pub fn from_streaming_boolean(is_streaming: bool) -> Self { + if is_streaming { + RequestType::Stream + } else { + RequestType::Unary + } + } +} + /// Status pub enum Status { Success, Error, + Rejected, } /// Track response-specific metrics pub struct ResponseMetricCollector { metrics: Arc, + rate_limiter: Arc, model: String, start_time: Instant, // we use is_first_token to distinguish TTFT from ITL. 
It is true by default and @@ -108,7 +128,19 @@ impl Metrics { /// - `{prefix}_http_service_output_sequence_tokens` - HistogramVec for output sequence length in tokens /// - `{prefix}_http_service_time_to_first_token_seconds` - HistogramVec for time to first token in seconds /// - `{prefix}_http_service_inter_token_latency_seconds` - HistogramVec for inter-token latency in seconds + /// - `{prefix}_http_service_rate_limit_requests_total` - IntCounterVec for the total number of requests rejected by the rate limiter + /// - `{prefix}_http_service_rate_limit_ema_ttft_seconds` - HistogramVec for time to first token in seconds + /// - `{prefix}_http_service_rate_limit_ema_itl_seconds` - HistogramVec for inter-token latency in seconds pub fn new(prefix: &str) -> Self { + let rate_limit_requests_counter = IntCounterVec::new( + Opts::new( + format!("{}_http_service_rate_limit_requests_total", prefix), + "Total number of requests rejected by the rate limiter", + ), + &["model", "endpoint", "request_type", "status"], + ) + .unwrap(); + let request_counter = IntCounterVec::new( Opts::new( format!("{}_http_service_requests_total", prefix), @@ -189,7 +221,33 @@ impl Metrics { ) .unwrap(); + let rate_limit_ema_ttft = HistogramVec::new( + HistogramOpts::new( + format!("{}_http_service_rate_limit_ema_ttft_seconds", prefix), + "Time to first token in seconds", + ) + .buckets(vec![ + 0.0, 0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, + 60.0, 120.0, 240.0, 480.0, + ]), + &["model"], + ) + .unwrap(); + + let rate_limit_ema_itl = HistogramVec::new( + HistogramOpts::new( + format!("{}_http_service_rate_limit_ema_itl_seconds", prefix), + "Inter-token latency in seconds", + ) + .buckets(vec![ + 0.0, 0.001, 0.005, 0.01, 0.015, 0.02, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.0, + ]), + &["model"], + ) + .unwrap(); + Metrics { + rate_limit_requests_counter, request_counter, inflight_gauge, request_duration, @@ -197,9 +255,33 @@ impl Metrics { output_sequence_length, 
time_to_first_token, inter_token_latency, + rate_limit_ema_ttft, + rate_limit_ema_itl, } } + /// Get the number of requests rejected by the rate limiter for the given dimensions: + /// - model + /// - endpoint (completions/chat_completions) + /// - request type (unary/stream) + /// - status (success/error) + pub fn get_rate_limit_requests_counter( + &self, + model: &str, + endpoint: &Endpoint, + request_type: &RequestType, + status: &Status, + ) -> u64 { + self.rate_limit_requests_counter + .with_label_values(&[ + model, + endpoint.as_str(), + request_type.as_str(), + status.as_str(), + ]) + .get() + } + /// Get the number of successful requests for the given dimensions: /// - model /// - endpoint (completions/chat_completions) @@ -222,6 +304,28 @@ impl Metrics { .get() } + /// Increment the counter for requests rejected by the rate limiter for the given dimensions: + /// - model + /// - endpoint (completions/chat_completions) + /// - request type (unary/stream) + /// - status (success/error) + pub fn inc_rate_limit_requests_counter( + &self, + model: &str, + endpoint: &Endpoint, + request_type: &RequestType, + status: &Status, + ) { + self.rate_limit_requests_counter + .with_label_values(&[ + model, + endpoint.as_str(), + request_type.as_str(), + status.as_str(), + ]) + .inc() + } + /// Increment the counter for requests for the given dimensions: /// - model /// - endpoint (completions/chat_completions) @@ -258,6 +362,7 @@ impl Metrics { } pub fn register(&self, registry: &Registry) -> Result<(), prometheus::Error> { + registry.register(Box::new(self.rate_limit_requests_counter.clone()))?; registry.register(Box::new(self.request_counter.clone()))?; registry.register(Box::new(self.inflight_gauge.clone()))?; registry.register(Box::new(self.request_duration.clone()))?; @@ -265,6 +370,8 @@ impl Metrics { registry.register(Box::new(self.output_sequence_length.clone()))?; registry.register(Box::new(self.time_to_first_token.clone()))?; 
registry.register(Box::new(self.inter_token_latency.clone()))?; + registry.register(Box::new(self.rate_limit_ema_ttft.clone()))?; + registry.register(Box::new(self.rate_limit_ema_itl.clone()))?; Ok(()) } @@ -294,8 +401,26 @@ impl Metrics { } /// Create a new [`ResponseMetricCollector`] for collecting per-response metrics (i.e., TTFT, ITL) - pub fn create_response_collector(self: Arc, model: &str) -> ResponseMetricCollector { - ResponseMetricCollector::new(self, model.to_string().to_lowercase()) + pub fn create_response_collector( + self: Arc, + model: &str, + rate_limiter: Arc, + ) -> ResponseMetricCollector { + ResponseMetricCollector::new(self, model.to_string().to_lowercase(), rate_limiter) + } + + /// Record the time to first token for the given model, endpoint, and request type + pub fn record_rate_limit_ttft(&self, ttft_ema: f64, model: &str) { + self.rate_limit_ema_ttft + .with_label_values(&[model]) + .observe(ttft_ema); + } + + /// Record the inter-token latency for the given model, endpoint, and request type + pub fn record_rate_limit_itl(&self, itl_ema: f64, model: &str) { + self.rate_limit_ema_itl + .with_label_values(&[model]) + .observe(itl_ema); } } @@ -387,14 +512,16 @@ impl Status { match self { Status::Success => REQUEST_STATUS_SUCCESS, Status::Error => REQUEST_STATUS_ERROR, + Status::Rejected => REQUEST_STATUS_REJECTED, } } } impl ResponseMetricCollector { - fn new(metrics: Arc, model: String) -> Self { + fn new(metrics: Arc, model: String, rate_limiter: Arc) -> Self { ResponseMetricCollector { metrics, + rate_limiter, model, is_first_token: true, last_response_time: None, @@ -425,6 +552,7 @@ impl ResponseMetricCollector { .time_to_first_token .with_label_values(&[&self.model]) .observe(ttft); + self.rate_limiter.record_ttft(&self.model, ttft); // Publish ISL // TODO: publish ISL as soon as the tokenization process completes @@ -445,6 +573,7 @@ impl ResponseMetricCollector { .with_label_values(&[&self.model]) .observe(itl); } + 
self.rate_limiter.record_itl(&self.model, itl); } self.last_response_time = Some(current_duration); diff --git a/lib/llm/src/http/service/openai.rs b/lib/llm/src/http/service/openai.rs index 2c23d84f7a..3fc669bd0f 100644 --- a/lib/llm/src/http/service/openai.rs +++ b/lib/llm/src/http/service/openai.rs @@ -30,15 +30,21 @@ use super::{ metrics::{Endpoint, ResponseMetricCollector}, service_v2, RouteDoc, }; -use crate::preprocessor::LLMMetricAnnotation; -use crate::protocols::openai::{ - chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionResponse}, - completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, - embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse}, - responses::{NvCreateResponse, NvResponse}, -}; use crate::request_template::RequestTemplate; use crate::types::Annotated; +use crate::{ + http::service::metrics::{RequestType, Status}, + preprocessor::LLMMetricAnnotation, +}; +use crate::{ + http::service::rate_limiter::ShouldRejectResult, + protocols::openai::{ + chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionResponse}, + completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, + embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse}, + responses::{NvCreateResponse, NvResponse}, + }, +}; pub const DYNAMO_REQUEST_ID_HEADER: &str = "x-dynamo-request-id"; @@ -101,6 +107,17 @@ impl ErrorMessage { ) } + /// Rate Limit Exceeded + /// Return this error when the request is rejected due to rate limiting. + pub fn rate_limit_exceeded(msg: &str) -> ErrorResponse { + ( + StatusCode::TOO_MANY_REQUESTS, + Json(ErrorMessage { + error: msg.to_string(), + }), + ) + } + /// The OAI endpoints call an [`dynamo.runtime::engine::AsyncEngine`] which are specialized to return /// an [`anyhow::Error`]. This method will convert the [`anyhow::Error`] into an [`HttpError`]. 
/// If successful, it will return the [`HttpError`] as an [`ErrorMessage::internal_server_error`] @@ -206,6 +223,14 @@ async fn completions( // return a 503 if the service is not ready check_ready(&state)?; + // Rate limit check + should_reject_request( + &state, + &request.inner.model, + &Endpoint::Completions, + &RequestType::from_streaming_boolean(request.inner.stream.unwrap_or(false)), + )?; + // todo - extract distributed tracing id and context id from headers let request_id = uuid::Uuid::new_v4().to_string(); @@ -233,7 +258,9 @@ async fn completions( .metrics_clone() .create_inflight_guard(model, Endpoint::Completions, streaming); - let mut response_collector = state.metrics_clone().create_response_collector(model); + let mut response_collector = state + .metrics_clone() + .create_response_collector(model, state.rate_limiter_clone()); // prepare to process any annotations let annotations = request.annotations(); @@ -405,6 +432,14 @@ async fn chat_completions( // return a 503 if the service is not ready check_ready(&state)?; + // Rate limit check + should_reject_request( + &state, + &request.inner.model, + &Endpoint::ChatCompletions, + &RequestType::from_streaming_boolean(request.inner.stream.unwrap_or(false)), + )?; + let request_id = request.id().to_string(); // Handle unsupported fields - if Some(resp) is returned by @@ -453,7 +488,9 @@ async fn chat_completions( .metrics_clone() .create_inflight_guard(model, Endpoint::ChatCompletions, streaming); - let mut response_collector = state.metrics_clone().create_response_collector(model); + let mut response_collector = state + .metrics_clone() + .create_response_collector(model, state.rate_limiter_clone()); tracing::trace!("Issuing generate call for chat completions"); let annotations = request.annotations(); @@ -602,6 +639,15 @@ async fn responses( // return a 503 if the service is not ready check_ready(&state)?; + // Rate limit check + // TODO: handle streaming, currently just unary + should_reject_request( + 
&state, + &request.inner.model, + &Endpoint::Responses, + &RequestType::Unary, + )?; + // Handle unsupported fields - if Some(resp) is returned by validate_unsupported_fields, // then a field was used that is unsupported. We will log an error message // and early return a 501 NOT_IMPLEMENTED status code. Otherwise, proceeed. @@ -665,7 +711,9 @@ async fn responses( .metrics_clone() .create_inflight_guard(model, Endpoint::Responses, false); - let _response_collector = state.metrics_clone().create_response_collector(model); + let _response_collector = state + .metrics_clone() + .create_response_collector(model, state.rate_limiter_clone()); tracing::trace!("Issuing generate call for chat completions"); @@ -705,6 +753,44 @@ async fn responses( Ok(Json(response).into_response()) } +pub fn should_reject_request( + state: &Arc, + model: &str, + endpoint: &Endpoint, + request_type: &RequestType, +) -> Result<(), ErrorResponse> { + if !state.rate_limiter().is_enabled() { + return Ok(()); + } + + let ShouldRejectResult { + should_reject, + decayed_ttft_ema, + decayed_itl_ema, + } = state.rate_limiter().should_reject(model); + + state + .metrics_clone() + .record_rate_limit_ttft(decayed_ttft_ema, model); + state + .metrics_clone() + .record_rate_limit_itl(decayed_itl_ema, model); + + if should_reject { + state.metrics_clone().inc_rate_limit_requests_counter( + model, + endpoint, + request_type, + &Status::Rejected, + ); + return Err(ErrorMessage::rate_limit_exceeded(&format!( + "Too many requests: rate limit exceeded for current request and model: {model}. Please retry later." 
+ ))); + } + + Ok(()) +} + pub fn validate_response_input_is_text_only( request: &NvCreateResponse, ) -> Option { diff --git a/lib/llm/src/http/service/rate_limiter.rs b/lib/llm/src/http/service/rate_limiter.rs new file mode 100644 index 0000000000..17cb455ed8 --- /dev/null +++ b/lib/llm/src/http/service/rate_limiter.rs @@ -0,0 +1,1204 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Rate limiter implementation for the OpenAI API compatible HTTP service. +//! +//! The rate limiter is used to limit the rate of requests to the HTTP service. +//! It is used to prevent abuse of the service, under heavy load, +//! and to ensure that the service is available to all users. The system values +//! 'good-put' (that is, the throughput of the system under good performance metrics, in the form of +//! time to first token and inter-token latency) over 'throughput' (that is, the total amount of +//! tokens processed), and the rate limiter is used to ensure that the service is available to all +//! users, even under heavy load. +//! +//! The rate limiter is implemented using a time-weighted exponential moving average (EMA). +//! The time-weighted average is computed using the following formula: +//! +//! ```text +//! average = sum(value * weight) / sum(weight) +//! ``` +//! +//! Where `weight` is the weight of the sample based on the age of the sample and the time constant: +//! +//! ```text +//! age = now - record_time +//! weight = exp(-age / time_constant_secs) +//! ``` +//! +//! Where `now` is the current time, `record_time` is the time the sample was recorded, +//! and `time_constant_secs` is the time constant for the time-weighted average. +//! +//! Moreover, we decay the average to account for the time elapsed since the last update. +//! This models "system recovery" during idle time. This is done by multiplying the average by the +//! decay factor: +//! +//! 
```text +//! decayed_average = average * exp(-time_elapsed / time_constant_secs) +//! ``` + +use std::time::Instant; + +use anyhow::Result; +use dashmap::DashMap; +use derive_builder::Builder; +use validator::Validate; + +/// Configuration for the rate limiter +#[derive(Debug, Clone, Builder, Validate)] +#[builder(pattern = "owned")] +pub struct RateLimiterConfig { + /// Threshold for the time to first token metric, + /// which defines the maximum allowed time to first token + /// in seconds. Any recorded time to first token above this threshold + /// will likely trigger a rate limit rejection for the next incoming request. + #[builder(default = "1.0")] + #[validate(range(min = 1e-2))] + ttft_threshold_secs: f64, + + /// Threshold for the inter-token latency metric, + /// which defines the maximum allowed inter-token latency + /// in seconds. Any recorded inter-token latency above this threshold + /// will likely trigger a rate limit rejection for the next incoming request. + #[builder(default = "0.1")] + #[validate(range(min = 1e-4))] + itl_threshold_secs: f64, + + /// Time constant for the time-weighted EMA, + /// that is, the time constant for the exponential moving average + /// of the time-weighted average. 
+ #[builder(default = "15.0")] + #[validate(range(min = 1e-2))] + time_constant_secs: f64, + + /// Whether to use per-model limits, that is, + /// to track rate limit metrics for each model separately + #[builder(default = "true")] + per_model_limits: bool, + + /// Whether the rate limiter is enabled + #[builder(default = "false")] + is_enabled: bool, +} + +impl RateLimiterConfig { + pub fn new( + ttft_threshold_secs: f64, + itl_threshold_secs: f64, + time_constant_secs: f64, + per_model_limits: bool, + ) -> Result { + let config: RateLimiterConfig = Self { + ttft_threshold_secs, + itl_threshold_secs, + time_constant_secs, + per_model_limits, + is_enabled: true, + }; + + config + .validate() + .map_err(|e| anyhow::anyhow!("Invalid rate limiter config: {}", e))?; + + Ok(config) + } + + pub fn empty() -> Self { + Self { + ttft_threshold_secs: 0.0, + itl_threshold_secs: 0.0, + time_constant_secs: 0.001, + per_model_limits: false, + is_enabled: false, + } + } + + pub fn is_enabled(&self) -> bool { + self.is_enabled + } + + pub fn builder() -> RateLimiterConfigBuilder { + RateLimiterConfigBuilder::default() + } +} + +impl Default for RateLimiterConfig { + fn default() -> Self { + Self { + ttft_threshold_secs: 1.0, // 1s + itl_threshold_secs: 0.1, // 100ms + time_constant_secs: 30.0, // 30s + per_model_limits: false, + is_enabled: true, + } + } +} + +/// Tracks recent samples to compute time-weighted averages of a metric. Formally, +/// the time-weighted average is defined as: +/// +/// ```text +/// average = sum(value * weight) / sum(weight) +/// ``` +/// +/// Where `weight` is the weight of the sample based on the age of the sample and the time constant: +/// +/// ```text +/// age = now - record_time +/// weight = exp(-age / time_constant_secs) +/// ``` +/// +/// Where `now` is the current time, `record_time` is the time the sample was recorded, +/// and `time_constant_secs` is the time constant for the time-weighted average. 
+/// In this way, more recent samples have a higher weight than older samples, the latter
+/// decaying exponentially towards zero (making it less impactful for the current average calculation).
+///
+/// In order to compute the time-weighted average more efficiently, we leverage the well
+/// known property of the exponential function:
+///
+/// ```text
+/// exp(x) = exp(y) * exp(x - y)
+/// ```
+///
+/// This allows us to compute the time-weighted average, recursively, in a single pass,
+/// (see Markov's property) as follows:
+///
+/// ```text
+/// previous_weight_total = sum(weight)
+/// updated_factor = 1 / (1 + previous_weight_total * exp(-age / time_constant_secs))
+/// average(now) = value * updated_factor + average(last_time) * (1 - updated_factor)
+/// ```
+#[derive(Debug)]
+pub struct TimeWeightedAverageTracker {
+    previous_weighted_average: f64,
+    previous_total_weight: f64,
+    previous_observed_time: Instant,
+    time_constant_secs: f64,
+}
+
+impl TimeWeightedAverageTracker {
+    pub fn new(time_constant_secs: f64) -> Self {
+        let now = Instant::now();
+        Self {
+            previous_weighted_average: 0.,
+            previous_total_weight: 0.,
+            previous_observed_time: now,
+            time_constant_secs,
+        }
+    }
+
+    /// Record a new value to the tracker.
+    pub fn record_value(&mut self, value: f64) {
+        let now = Instant::now();
+        if self.previous_weighted_average == 0. && self.previous_total_weight == 0. {
+            // First sample
+            self.previous_weighted_average = value;
+            self.previous_total_weight = 1.;
+        } else {
+            let time_elapsed = now
+                .duration_since(self.previous_observed_time)
+                .as_secs_f64();
+            let decay_factor = (-time_elapsed / self.time_constant_secs).exp();
+
+            // Update the weighted average, using recursive EMA formula
+            self.previous_total_weight = 1. + self.previous_total_weight * decay_factor;
+            let alpha = 1. / self.previous_total_weight;
+            self.previous_weighted_average =
+                alpha * value + (1.
- alpha) * self.previous_weighted_average; + } + + self.previous_observed_time = now; + } + + /// Get the current time-weighted average, decayed to account for the time elapsed since the last update. + pub fn get_decayed_time_weighted_average(&self) -> f64 { + let now = Instant::now(); + let time_elapsed = now + .duration_since(self.previous_observed_time) + .as_secs_f64(); + let decay_factor = (-time_elapsed / self.time_constant_secs).exp(); + self.previous_weighted_average * decay_factor + } +} + +#[derive(Debug)] +struct ModelMetrics { + ttft_tracker: TimeWeightedAverageTracker, + itl_tracker: TimeWeightedAverageTracker, +} + +impl ModelMetrics { + fn new(config: &RateLimiterConfig) -> Self { + let ttft_tracker = TimeWeightedAverageTracker::new(config.time_constant_secs); + let itl_tracker = TimeWeightedAverageTracker::new(config.time_constant_secs); + + Self { + ttft_tracker, + itl_tracker, + } + } +} + +#[derive(Debug, Clone)] +pub struct RateLimiterMetrics { + pub ttft_diagnostics: TimeWeightedDiagnostics, + pub itl_diagnostics: TimeWeightedDiagnostics, +} + +impl std::fmt::Display for RateLimiterMetrics { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "RateLimiterMetrics {{\n TTFT: {},\n ITL: {}\n}}", + self.ttft_diagnostics, self.itl_diagnostics + ) + } +} + +#[derive(Debug, Clone)] +pub struct TimeWeightedDiagnostics { + pub decayed_time_weighted_average: f64, + pub time_constant_secs: f64, + pub previous_weighted_sum: f64, + pub previous_observed_time: Instant, +} + +impl std::fmt::Display for TimeWeightedDiagnostics { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "TimeWeightedDiagnostics {{ \ + decayed_time_weighted_average: {:.3}, \ + time_constant_secs: {:.1}, \ + previous_weighted_sum: {:.3}, \ + duration_since_last_update: {:?} \ + }}", + self.decayed_time_weighted_average, + self.time_constant_secs, + self.previous_weighted_sum, + 
self.previous_observed_time.elapsed().as_secs_f64() + ) + } +} + +pub struct RateLimiter { + config: RateLimiterConfig, + model_metrics: DashMap, +} + +impl RateLimiter { + pub fn new(config: RateLimiterConfig) -> Self { + Self { + config, + model_metrics: DashMap::new(), + } + } + + pub fn is_enabled(&self) -> bool { + self.config.is_enabled + } + + #[inline] + fn get_model_key(&self, model: &str) -> String { + if self.config.per_model_limits { + model.to_string() + } else { + "global".to_string() + } + } + + /// Record the time to first token metric for a given model + pub fn record_ttft(&self, model: &str, ttft_ms: f64) { + let model_key = self.get_model_key(model); + let mut model_metrics = self + .model_metrics + .entry(model_key) + .or_insert_with(|| ModelMetrics::new(&self.config)); + + model_metrics.ttft_tracker.record_value(ttft_ms); + } + + /// Record the inter-token latency metric for a given model + pub fn record_itl(&self, model: &str, itl_ms: f64) { + let model_key = self.get_model_key(model); + let mut model_metrics = self + .model_metrics + .entry(model_key) + .or_insert_with(|| ModelMetrics::new(&self.config)); + + model_metrics.itl_tracker.record_value(itl_ms); + } + + /// Check if the request should be rejected based on the cached metrics + /// + /// Returns true if the request should be rejected, false otherwise + pub fn should_reject(&self, model: &str) -> ShouldRejectResult { + let model_key = self.get_model_key(model); + let model_metrics = self.model_metrics.get(&model_key); + + let Some(model_metrics) = model_metrics else { + return ShouldRejectResult { + should_reject: false, + decayed_ttft_ema: 0.0, + decayed_itl_ema: 0.0, + }; + }; + + // Get decayed time-weighted EMA values + let decayed_ttft_ema = model_metrics + .ttft_tracker + .get_decayed_time_weighted_average(); + let decayed_itl_ema = model_metrics + .itl_tracker + .get_decayed_time_weighted_average(); + + drop(model_metrics); + + let ttft_exceeded = 
self.config.ttft_threshold_secs < decayed_ttft_ema; + let itl_exceeded = self.config.itl_threshold_secs < decayed_itl_ema; + + if ttft_exceeded || itl_exceeded { + let rate_limiter_metrics = self.get_metrics(&model_key); + self.log_metrics(model, rate_limiter_metrics, true); + return ShouldRejectResult { + should_reject: true, + decayed_ttft_ema, + decayed_itl_ema, + }; + } + + if decayed_ttft_ema > self.config.ttft_threshold_secs * 0.9 + || decayed_itl_ema > self.config.itl_threshold_secs * 0.9 + { + let rate_limiter_metrics = self.get_metrics(&model_key); + self.log_metrics(model, rate_limiter_metrics, false); + } + + ShouldRejectResult { + should_reject: false, + decayed_ttft_ema, + decayed_itl_ema, + } + } + + /// Get current metrics and diagnostics for current model + #[inline] + fn get_metrics(&self, model_key: &str) -> RateLimiterMetrics { + let model_metrics = self.model_metrics.get(model_key).unwrap(); + let decayed_ttft_ema = model_metrics + .ttft_tracker + .get_decayed_time_weighted_average(); + let decayed_itl_ema = model_metrics + .itl_tracker + .get_decayed_time_weighted_average(); + let ttft_previous_weighted_sum = model_metrics.ttft_tracker.previous_total_weight; + let itl_previous_weighted_sum = model_metrics.itl_tracker.previous_total_weight; + let ttft_previous_observed_time = model_metrics.ttft_tracker.previous_observed_time; + let itl_previous_observed_time = model_metrics.itl_tracker.previous_observed_time; + + RateLimiterMetrics { + ttft_diagnostics: TimeWeightedDiagnostics { + decayed_time_weighted_average: decayed_ttft_ema, + time_constant_secs: self.config.time_constant_secs, + previous_weighted_sum: ttft_previous_weighted_sum, + previous_observed_time: ttft_previous_observed_time, + }, + itl_diagnostics: TimeWeightedDiagnostics { + decayed_time_weighted_average: decayed_itl_ema, + time_constant_secs: self.config.time_constant_secs, + previous_weighted_sum: itl_previous_weighted_sum, + previous_observed_time: itl_previous_observed_time, + 
}, + } + } + + fn log_metrics(&self, model: &str, metrics: RateLimiterMetrics, has_exceeded: bool) { + if has_exceeded { + tracing::warn!( + model = model, + ttft_threshold_secs = self.config.ttft_threshold_secs, + itl_threshold_secs = self.config.itl_threshold_secs, + "Rate limit exceeded for model {model}: {metrics}", + metrics = metrics, + ); + } else { + tracing::info!( + model = model, + ttft_threshold_secs = self.config.ttft_threshold_secs, + itl_threshold_secs = self.config.itl_threshold_secs, + "Approaching rate limit thresholds. Current rate limit metrics for model {model}: {metrics}", + metrics = metrics, + ); + } + } +} + +pub struct ShouldRejectResult { + pub should_reject: bool, + pub decayed_ttft_ema: f64, + pub decayed_itl_ema: f64, +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::Ordering; + use std::sync::{atomic::AtomicUsize, Arc}; + use std::thread; + use std::time::Duration; + + #[test] + fn test_simple_time_weighted_average_tracker() { + const TIME_CONSTANT_SECS: f64 = 1.0; // Short time constant + + const SLEEP_DURATION_MS: u64 = 100; + + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + // Add samples with increasing delays + tracker.record_value(100.0); + thread::sleep(Duration::from_millis(SLEEP_DURATION_MS)); + tracker.record_value(200.0); + thread::sleep(Duration::from_millis(SLEEP_DURATION_MS)); + tracker.record_value(300.0); + thread::sleep(Duration::from_millis(SLEEP_DURATION_MS)); + tracker.record_value(400.0); + thread::sleep(Duration::from_millis(20 * SLEEP_DURATION_MS)); // Long gap + tracker.record_value(500.0); + + let avg = tracker.get_decayed_time_weighted_average(); + assert!(avg > 0.0, "Average should be positive"); + } + + #[test] + fn test_edge_case_all_samples_below_threshold() { + const TIME_CONSTANT_SECS: f64 = 0.1; + const SLEEP_DURATION_MS: u64 = 2_000; + const EPSILON: f64 = 1e-5; + + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + 
tracker.record_value(100.0); + thread::sleep(Duration::from_millis(SLEEP_DURATION_MS)); + + let avg = tracker.get_decayed_time_weighted_average(); + + // Should return a close to 0.0 when time constant is small and time passed is large + assert!(avg < EPSILON, "Average should be 0.0: {}", avg); + } + + #[test] + fn test_edge_case_single_sample() { + const TIME_CONSTANT_SECS: f64 = 10.; + const EPSILON: f64 = 0.5; // exp(-0.01) ~= 0.99 + + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + tracker.record_value(42.); + thread::sleep(Duration::from_millis(100)); + + let avg = tracker.get_decayed_time_weighted_average(); + + assert!( + (avg - 42.).abs() < EPSILON, + "Average should be close to 42.0: {}", + avg + ); + } + + #[test] + fn test_time_weighted_average_tracker_correctness() { + const TIME_CONSTANT_SECS: f64 = 10.0; + const NUM_SAMPLES: usize = 100; + const SLEEP_DURATION_MS: u64 = 1; + + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + // Record old sample + tracker.record_value(1000.0); + thread::sleep(Duration::from_millis(1_000 * SLEEP_DURATION_MS)); + + // Add more recent samples with lower values + for _ in 0..NUM_SAMPLES { + tracker.record_value(100.0); + thread::sleep(Duration::from_millis(SLEEP_DURATION_MS)); + } + + let avg = tracker.get_decayed_time_weighted_average(); + assert!( + avg < 500.0, + "Average should be dominated by recent samples: {}", + avg + ); + assert!( + avg > 100.0, + "Average should still be influenced by old sample: {}", + avg + ); + } + + #[test] + fn test_time_weighted_average_quantitative_analysis() { + const TIME_CONSTANT_SECS: f64 = 2.0; // 2 second time constant + const EPSILON: f64 = 0.05; // 5% tolerance for timing precision + const SLEEP_DURATION_MS: u64 = 100; // 100ms + + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + // Record samples with known values and controlled timing + let sample_values = [100.0, 200.0, 300.0, 400.0]; + let 
sample_delays_ms = [0, 500, 1000, 1500]; // Delays in milliseconds + + let start_time = Instant::now(); + + // Record first sample immediately + tracker.record_value(sample_values[0]); + + // Record subsequent samples with known delays + for i in 1..sample_values.len() { + thread::sleep(Duration::from_millis( + sample_delays_ms[i] - sample_delays_ms[i - 1], + )); + tracker.record_value(sample_values[i]); + } + + // Wait a bit more, then calculate + thread::sleep(Duration::from_millis(SLEEP_DURATION_MS)); + let calculation_time = Instant::now(); + + // Calculate expected weighted average manually + let total_elapsed = calculation_time.duration_since(start_time); + let mut expected_weighted_sum = 0.0; + let mut expected_total_weight = 0.0; + + for i in 0..sample_values.len() { + // Age of this sample = total_elapsed - delay_when_recorded + let sample_age_secs = + total_elapsed.as_secs_f64() - (sample_delays_ms[i] as f64 / 1000.0); + let weight = f64::exp(-sample_age_secs / TIME_CONSTANT_SECS); + + expected_weighted_sum += sample_values[i] * weight; + expected_total_weight += weight; + + println!( + "Sample {}: value={}, age={:.3}s, weight={:.6}", + i, sample_values[i], sample_age_secs, weight + ); + } + + let expected_average = + (expected_weighted_sum / expected_total_weight) * f64::exp(-0.1 / TIME_CONSTANT_SECS); // 0.1s is the time elapsed since the last sample + let actual_average = tracker.get_decayed_time_weighted_average(); + + println!("Expected average: {:.6}", expected_average); + println!("Actual average: {:.6}", actual_average); + println!( + "Difference: {:.6}", + (actual_average - expected_average).abs() + ); + println!( + "Relative error: {:.4}%", + 100.0 * (actual_average - expected_average).abs() / expected_average + ); + + // Verify the calculation is mathematically correct within tolerance + let relative_error = (actual_average - expected_average).abs() / expected_average; + assert!( + relative_error < EPSILON, + "Time-weighted average calculation 
error too large: expected {:.6}, got {:.6}, relative error {:.4}%", + expected_average, actual_average, relative_error * 100.0 + ); + + // Additional verification: more recent samples should have higher influence + // Sample 3 (400.0) is most recent, so if we compare with a simple average: + let simple_average = sample_values.iter().sum::() / sample_values.len() as f64; + println!("Simple average: {:.6}", simple_average); + + // The time-weighted average should be closer to the most recent value (400.0) + // than the simple average, since recent samples have higher weights + let distance_to_recent = (actual_average - 400.0).abs(); + let distance_simple_to_recent = (simple_average - 400.0).abs(); + + assert!( + distance_to_recent < distance_simple_to_recent, + "Time-weighted average should be closer to recent values than simple average: \ + weighted_avg={:.2}, simple_avg={:.2}, recent_value=400.0", + actual_average, + simple_average + ); + } + + #[test] + fn test_exponential_decay_verification() { + const TIME_CONSTANT_SECS: f64 = 1.0; // 1 second time constant + const EPSILON: f64 = 0.02; // 2% tolerance + + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + // Record a high value, then wait exactly one time constant + tracker.record_value(1000.0); + thread::sleep(Duration::from_millis(1_000)); // Wait 1 second = 1 time constant + + // Record a low value + tracker.record_value(100.0); + + let actual_average = tracker.get_decayed_time_weighted_average(); + + // After 1 time constant, the old sample should have weight = e^(-1) ≈ 0.368 + // New sample has weight ≈ 1.0 + let old_weight = f64::exp(-1.0); // ≈ 0.368 + let new_weight = 1.0; + + let expected_average = + (1000.0 * old_weight + 100.0 * new_weight) / (old_weight + new_weight); + + println!("Old weight (e^-1): {:.6}", old_weight); + println!("New weight: {:.6}", new_weight); + println!("Expected average: {:.6}", expected_average); + println!("Actual average: {:.6}", actual_average); + + 
let relative_error = (actual_average - expected_average).abs() / expected_average; + assert!( + relative_error < EPSILON, + "Exponential decay verification failed: expected {:.6}, got {:.6}, error {:.4}%", + expected_average, + actual_average, + relative_error * 100.0 + ); + + // Verify the theoretical calculation: should be around 463.4 + assert!( + (expected_average - 342.04727).abs() < 1e-5, + "Theoretical calculation seems wrong: {:.1}", + expected_average + ); + } + + #[test] + fn test_mathematical_properties() { + const TIME_CONSTANT_SECS: f64 = 2.0; + + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + // Property 1: Single sample should return its own value + tracker.record_value(42.0); + let single_avg = tracker.get_decayed_time_weighted_average(); + assert!( + (single_avg - 42.0).abs() < 1e-5, + "Single sample average should equal sample value: {}", + single_avg + ); + + // Property 2: Adding identical samples should not change average + tracker.record_value(42.0); + tracker.record_value(42.0); + let identical_avg = tracker.get_decayed_time_weighted_average(); + assert!( + (identical_avg - 42.0).abs() < 1e-5, + "Identical samples should maintain average: {}", + identical_avg + ); + + // Property 3: Average should be bounded by min and max values + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + let values = vec![10.0, 50.0, 30.0, 70.0, 20.0]; + let min_val = 10.0; + let max_val = 70.0; + + for &val in &values { + tracker.record_value(val); + thread::sleep(Duration::from_millis(10)); + } + + let bounded_avg = tracker.get_decayed_time_weighted_average(); + assert!( + bounded_avg >= min_val && bounded_avg <= max_val, + "Average should be bounded: {:.2} not in [{:.2}, {:.2}]", + bounded_avg, + min_val, + max_val + ); + + println!("Values: {:?}", values); + println!( + "Average: {:.2} ∈ [{:.2}, {:.2}] ✓", + bounded_avg, min_val, max_val + ); + } + + #[test] + fn test_concurrent_access_simulation() { + const 
NUM_THREADS: usize = 10; + const NUM_RECORDS: usize = 100; + + const SLEEP_INTERVAL: usize = 10; + const SLEEP_DURATION_MS: u64 = 1; + + let config = RateLimiterConfig { + ttft_threshold_secs: 1.0, + itl_threshold_secs: 0.1, + time_constant_secs: 30.0, + per_model_limits: false, + is_enabled: true, + }; + let limiter = Arc::new(RateLimiter::new(config)); + let error_count = Arc::new(AtomicUsize::new(0)); + + let mut handles = Vec::new(); + + for i in 0..NUM_THREADS { + let limiter_clone = limiter.clone(); + let error_count_clone = error_count.clone(); + + handles.push(thread::spawn(move || { + for j in 0..NUM_RECORDS { + limiter_clone.record_ttft("model", (i * NUM_RECORDS + j) as f64 / 10_000.0); + limiter_clone.record_itl("model", (i + j) as f64 / 1_000.0); + + if limiter_clone.should_reject("model").should_reject { + error_count_clone.fetch_add(1, Ordering::Relaxed); + } + + if j % SLEEP_INTERVAL == 0 { + thread::sleep(Duration::from_millis(SLEEP_DURATION_MS)); + } + } + })); + } + + for handle in handles { + handle.join().unwrap(); + } + + // Should have no errors + let error_count = error_count.load(Ordering::Relaxed); + assert_eq!(error_count, 0, "Error count should be 0: {}", error_count); + } + + #[test] + fn test_concurrent_access_simulation_with_error_count() { + const NUM_THREADS: usize = 10; + const NUM_RECORDS: usize = 100; + + const SLEEP_INTERVAL: usize = 10; + const SLEEP_DURATION_MS: u64 = 1; + + let config = RateLimiterConfig { + ttft_threshold_secs: 1.0, + itl_threshold_secs: 0.1, + time_constant_secs: 30.0, + per_model_limits: false, + is_enabled: true, + }; + let limiter = Arc::new(RateLimiter::new(config)); + let error_count = Arc::new(AtomicUsize::new(0)); + + let mut handles = Vec::new(); + + for i in 0..NUM_THREADS { + let limiter_clone = limiter.clone(); + let error_count_clone = error_count.clone(); + + handles.push(thread::spawn(move || { + for j in 0..NUM_RECORDS { + limiter_clone.record_ttft("model", (i * NUM_RECORDS + j) as f64 / 
1_000.0);
+                    limiter_clone.record_itl("model", (i + j) as f64 / 100.0);
+
+                    if limiter_clone.should_reject("model").should_reject {
+                        error_count_clone.fetch_add(1, Ordering::Relaxed);
+                    }
+
+                    if j % SLEEP_INTERVAL == 0 {
+                        thread::sleep(Duration::from_millis(SLEEP_DURATION_MS));
+                    }
+                }
+            }));
+        }
+
+        for handle in handles {
+            handle.join().unwrap();
+        }
+
+        // Roughly 90% of the 1000 should_reject checks are expected to be rejected here;
+        // we assert on an 87%-93% band to account for the effect of the time passing.
+        let error_count = error_count.load(Ordering::Relaxed);
+        assert!(
+            error_count > 870 && error_count < 930,
+            "Error count should be around 90% of the checks: {}",
+            error_count
+        );
+    }
+
+    #[test]
+    fn test_concurrent_operations() {
+        use std::sync::Mutex;
+
+        const TIME_CONSTANT_SECS: f64 = 10.0;
+
+        const SLEEP_DURATION_MS: u64 = 1;
+
+        const NUM_THREADS: usize = 5;
+        const NUM_RECORDS: usize = 20;
+
+        let tracker = Arc::new(Mutex::new(TimeWeightedAverageTracker::new(
+            TIME_CONSTANT_SECS,
+        )));
+
+        let mut handles = Vec::new();
+
+        // Spawn multiple threads adding values
+        for thread_id in 0..NUM_THREADS {
+            let tracker_clone = Arc::clone(&tracker);
+            let handle = thread::spawn(move || {
+                for i in 0..NUM_RECORDS {
+                    let value = (thread_id * 100 + i) as f64;
+                    tracker_clone.lock().unwrap().record_value(value);
+                    thread::sleep(Duration::from_millis(SLEEP_DURATION_MS));
+                }
+            });
+            handles.push(handle);
+        }
+
+        // Also spawn a thread that computes averages
+        const NUM_AVERAGES: usize = 10;
+        const SLEEP_DURATION_MS_AVERAGE: u64 = 5;
+
+        let tracker_clone = Arc::clone(&tracker);
+        let avg_handle = thread::spawn(move || {
+            for _ in 0..NUM_AVERAGES {
+                let avg = tracker_clone
+                    .lock()
+                    .unwrap()
+                    .get_decayed_time_weighted_average();
+                assert!(avg > 0.0, "Average should be positive");
+                thread::sleep(Duration::from_millis(SLEEP_DURATION_MS_AVERAGE));
+            }
+        });
+
+        // Wait for all threads
+        for handle in handles {
+            handle.join().unwrap();
+        }
+        avg_handle.join().unwrap();
+
let final_avg = tracker.lock().unwrap().get_decayed_time_weighted_average(); + assert!(final_avg > 0.0, "Final average should be positive"); + } + + #[test] + fn test_rate_limiter_integration() { + let config = RateLimiterConfig { + ttft_threshold_secs: 100., // threshold in the same unit as the TTFT samples recorded below + itl_threshold_secs: 1., // ITL is never recorded in this test, so this threshold is not exercised + time_constant_secs: 1.0, + ..Default::default() + }; + + let limiter = Arc::new(RateLimiter::new(config)); + + // Record low values - should not trigger + limiter.record_ttft("test", 50.0); + limiter.record_ttft("test", 60.0); + limiter.record_ttft("test", 70.0); + + thread::sleep(Duration::from_millis(150)); // Wait for warmup + + assert!( + !limiter.should_reject("test").should_reject, + "Should not reject with low values" + ); + + // Record high values - should trigger + limiter.record_ttft("test", 200.0); + limiter.record_ttft("test", 300.0); + + assert!( + limiter.should_reject("test").should_reject, + "Should reject with high values" + ); + } + + #[test] + fn test_rate_limiter_integration_samples_close_to_trigger() { + const NUM_SAMPLES: usize = 100; + + let config = RateLimiterConfig { + ttft_threshold_secs: 70., + itl_threshold_secs: 0.005, + time_constant_secs: 1.0, + ..Default::default() + }; + + let limiter = Arc::new(RateLimiter::new(config)); + + // Record low values - should not trigger + limiter.record_ttft("test", 50.0); + limiter.record_ttft("test", 60.0); + limiter.record_ttft("test", 70.0); + + thread::sleep(Duration::from_millis(150)); // Wait for warmup + + assert!( + !limiter.should_reject("test").should_reject, + "Should not reject with low values" + ); + + // Record multiple values close to trigger + for i in 0..NUM_SAMPLES { + limiter.record_ttft("test", 100.0 + i as f64 / 10.0); + } + + assert!( + limiter.should_reject("test").should_reject, + "Should reject with high values" + ); + } + + #[test] + fn test_per_model_vs_global_limits() { + const MODEL_A: &str = "model_a"; + const MODEL_B: &str = "model_b"; + + let global_config = 
RateLimiterConfig { + per_model_limits: false, + ..Default::default() + }; + + let per_model_config = RateLimiterConfig { + per_model_limits: true, + ..Default::default() + }; + + let global_limiter = RateLimiter::new(global_config); + let per_model_limiter = RateLimiter::new(per_model_config); + + // Record high values for model A + global_limiter.record_ttft(MODEL_A, 2000.0); + global_limiter.record_ttft(MODEL_A, 2000.0); + + per_model_limiter.record_ttft(MODEL_A, 2000.0); + per_model_limiter.record_ttft(MODEL_A, 2000.0); + + thread::sleep(Duration::from_millis(20)); + + // Both should reject model A + assert!(global_limiter.should_reject(MODEL_A).should_reject); + assert!(per_model_limiter.should_reject(MODEL_A).should_reject); + + // Global limiter should also reject model B (uses same "global" key) + assert!(global_limiter.should_reject(MODEL_B).should_reject); + + // Per-model limiter should NOT reject model B (separate tracking) + assert!(!per_model_limiter.should_reject(MODEL_B).should_reject); + } + + #[test] + fn test_numerical_stability_long_time_series() { + // Scenario 1: Very small time constant with long time series + { + const TIME_CONSTANT_SECS: f64 = 0.01; // Very small time constant + const NUM_SAMPLES: usize = 10_000; + const SLEEP_DURATION_MICROS: u64 = 100; // 0.1ms between samples + + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + // Add many samples with controlled timing + for i in 0..NUM_SAMPLES { + tracker.record_value((i % 100) as f64); // Cycling values 0-99 + + if i % 1000 == 0 { + // Add occasional small delays to test very old samples + thread::sleep(Duration::from_micros(SLEEP_DURATION_MICROS)); + } + } + + let avg = tracker.get_decayed_time_weighted_average(); + + // Should be finite and reasonable + assert!( + avg.is_finite(), + "Average should be finite with small time constant" + ); + assert!( + avg > 0.0 && avg < 100.0, + "Average should be bounded by sample range: {}", + avg + ); + + // With small 
time constant and near-zero gaps between samples, the weights stay nearly equal, so the average tracks the mean of the cycling 0-99 values (~49.5) + assert!( + (avg - 50.0).abs() < 0.5, + "With small time constant, average should reflect recent samples: {}", + avg + ); + } + + // Scenario 2: Extreme value ranges + { + const TIME_CONSTANT_SECS: f64 = 1.0; + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + let extreme_values = vec![ + 1e-10, // Very small positive + 1e10, // Very large + 1e-15, // Tiny + 1e15, // Huge + 0.001, // Small + 1000000.0, // Large + ]; + + for &value in &extreme_values { + tracker.record_value(value); + thread::sleep(Duration::from_millis(1)); + } + + let avg = tracker.get_decayed_time_weighted_average(); + assert!( + avg.is_finite(), + "Average should handle extreme values gracefully" + ); + assert!(avg > 0.0, "Average of positive values should be positive"); + } + + // Scenario 3: Accumulated precision test with repetitive operations + { + const TIME_CONSTANT_SECS: f64 = 10.0; + const NUM_ITERATIONS: usize = 50_000; + + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + // Record the same value many times to test for accumulated rounding errors + let test_value = 42.42424242424242; // Value with many decimal places + + for _ in 0..NUM_ITERATIONS { + tracker.record_value(test_value); + } + + let avg = tracker.get_decayed_time_weighted_average(); + + // Should be very close to the test value (within reasonable floating point precision) + let relative_error = (avg - test_value).abs() / test_value; + assert!( + relative_error < 1e-6, + "Accumulated rounding error too large: {} vs {}, error: {:.2e}", + avg, + test_value, + relative_error + ); + } + + // Scenario 4: Weight underflow protection + { + const TIME_CONSTANT_SECS: f64 = 0.1; // Small time constant + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + // Add initial sample + tracker.record_value(1000.0); + + // Wait a very long time (relative to time constant) + 
thread::sleep(Duration::from_millis(2000)); // 20 time constants + + // Add recent samples - old sample should have negligible weight + for i in 0..100 { + tracker.record_value(100.0 + i as f64); + thread::sleep(Duration::from_micros(100)); + } + + let avg = tracker.get_decayed_time_weighted_average(); + + // Should be dominated by recent samples, not the old high value + assert!( + avg < 500.0, + "Very old samples should have negligible impact: {}", + avg + ); + assert!(avg > 100.0, "Average should still be reasonable: {}", avg); + } + + // Scenario 5: Monotonic behavior verification + { + const TIME_CONSTANT_SECS: f64 = 5.0; + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + // Add samples in strictly increasing order + let mut previous_avg = 0.0; + for i in 1..=1000 { + tracker.record_value(i as f64); + + if i % 100 == 0 { + let current_avg = tracker.get_decayed_time_weighted_average(); + + // Average should generally increase when adding larger values + assert!( + current_avg > previous_avg, + "Average should increase with larger values: {} -> {} at iteration {}", + previous_avg, + current_avg, + i + ); + + previous_avg = current_avg; + thread::sleep(Duration::from_millis(1)); + } + } + } + + // Scenario 6: Stability under rapid updates + { + const TIME_CONSTANT_SECS: f64 = 2.0; + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + // Rapidly add many samples without any delays + for i in 0..100_000 { + tracker.record_value((i % 10) as f64); // Values 0-9 + } + + let avg = tracker.get_decayed_time_weighted_average(); + + assert!(avg.is_finite(), "Rapid updates should maintain stability"); + assert!( + (0.0..=9.0).contains(&avg), + "Average should be bounded: {}", + avg + ); + + // Should be close to the mean of 0-9 = 4.5 + assert!( + (avg - 4.5).abs() < 1.0, + "Average should be close to sample mean with rapid updates: {}", + avg + ); + } + + // Scenario 7: Verify internal weight tracking remains stable + { + 
const TIME_CONSTANT_SECS: f64 = 1.0; + let mut tracker = TimeWeightedAverageTracker::new(TIME_CONSTANT_SECS); + + // Add samples over a long period to test weight accumulation + for i in 0..1000 { + tracker.record_value(50.0); // Constant value + + if i % 100 == 0 { + thread::sleep(Duration::from_millis(100)); + } + } + + // Internal weights should be reasonable (not infinite or zero) + let avg = tracker.get_decayed_time_weighted_average(); + assert!(avg.is_finite(), "Internal state should remain stable"); + assert!( + (avg - 50.0).abs() < 1e-4, + "Constant values should maintain constant average" + ); + + // Test that the tracker can still respond to new values + tracker.record_value(100.0); + let new_avg = tracker.get_decayed_time_weighted_average(); + assert!( + new_avg > 50.0, + "Tracker should still respond to new values after long series" + ); + } + + println!("✓ All numerical stability tests passed"); + } +} diff --git a/lib/llm/src/http/service/service_v2.rs b/lib/llm/src/http/service/service_v2.rs index 0b2af7763c..016269c755 100644 --- a/lib/llm/src/http/service/service_v2.rs +++ b/lib/llm/src/http/service/service_v2.rs @@ -9,6 +9,8 @@ use super::metrics; use super::Metrics; use super::RouteDoc; use crate::discovery::ModelManager; +use crate::http::service::rate_limiter::RateLimiter; +use crate::http::service::rate_limiter::RateLimiterConfig; use crate::request_template::RequestTemplate; use anyhow::Result; use derive_builder::Builder; @@ -19,13 +21,15 @@ use tokio_util::sync::CancellationToken; pub struct State { metrics: Arc, manager: Arc, + rate_limiter: Arc, } impl State { - pub fn new(manager: Arc) -> Self { + pub fn new(manager: Arc, rate_limiter: Arc) -> Self { Self { manager, metrics: Arc::new(Metrics::default()), + rate_limiter, } } @@ -42,6 +46,14 @@ impl State { self.manager.clone() } + pub fn rate_limiter(&self) -> &RateLimiter { + Arc::as_ref(&self.rate_limiter) + } + + pub fn rate_limiter_clone(&self) -> Arc { + self.rate_limiter.clone() + } 
+ // TODO pub fn sse_keep_alive(&self) -> Option { None @@ -84,6 +96,9 @@ pub struct HttpServiceConfig { #[builder(default = "None")] request_template: Option, + + #[builder(default = "RateLimiterConfig::empty()")] + rate_limiter_config: RateLimiterConfig, } impl HttpService { @@ -155,7 +170,8 @@ impl HttpServiceConfigBuilder { let config: HttpServiceConfig = self.build_internal()?; let model_manager = Arc::new(ModelManager::new()); - let state = Arc::new(State::new(model_manager)); + let rate_limiter = Arc::new(RateLimiter::new(config.rate_limiter_config)); + let state = Arc::new(State::new(model_manager, rate_limiter)); // enable prometheus metrics let registry = metrics::Registry::new(); diff --git a/lib/llm/src/local_model.rs b/lib/llm/src/local_model.rs index c32ca25bdb..a70612e948 100644 --- a/lib/llm/src/local_model.rs +++ b/lib/llm/src/local_model.rs @@ -16,6 +16,7 @@ use dynamo_runtime::{ use crate::discovery::ModelEntry; use crate::entrypoint::RouterConfig; +use crate::http::service::rate_limiter::RateLimiterConfig; use crate::model_card::{self, ModelDeploymentCard}; use crate::model_type::ModelType; use crate::request_template::RequestTemplate; @@ -46,6 +47,7 @@ pub struct LocalModelBuilder { router_config: Option, kv_cache_block_size: u32, http_port: u16, + rate_limiter_config: Option, migration_limit: u32, } @@ -61,6 +63,7 @@ impl Default for LocalModelBuilder { context_length: Default::default(), template_file: Default::default(), router_config: Default::default(), + rate_limiter_config: Default::default(), migration_limit: Default::default(), } } @@ -114,6 +117,11 @@ impl LocalModelBuilder { self } + pub fn rate_limiter_config(&mut self, rate_limiter_config: RateLimiterConfig) -> &mut Self { + self.rate_limiter_config = Some(rate_limiter_config); + self + } + pub fn migration_limit(&mut self, migration_limit: Option) -> &mut Self { self.migration_limit = migration_limit.unwrap_or(0); self @@ -155,6 +163,7 @@ impl LocalModelBuilder { template, 
http_port: self.http_port, router_config: self.router_config.take().unwrap_or_default(), + rate_limiter_config: self.rate_limiter_config.take().unwrap_or_default(), }); } @@ -212,6 +221,7 @@ impl LocalModelBuilder { template, http_port: self.http_port, router_config: self.router_config.take().unwrap_or_default(), + rate_limiter_config: self.rate_limiter_config.take().unwrap_or_default(), }) } } @@ -224,6 +234,7 @@ pub struct LocalModel { template: Option, http_port: u16, // Only used if input is HTTP server router_config: RouterConfig, + rate_limiter_config: RateLimiterConfig, // Only used if input is HTTP server } impl LocalModel { @@ -255,6 +266,10 @@ impl LocalModel { &self.router_config } + pub fn rate_limiter_config(&self) -> RateLimiterConfig { + self.rate_limiter_config.clone() + } + pub fn is_gguf(&self) -> bool { // GGUF is the only file (not-folder) we accept, so we don't need to check the extension // We will error when we come to parse it diff --git a/lib/llm/tests/http-service.rs b/lib/llm/tests/http-service.rs index 982e4aa5da..d475ad4217 100644 --- a/lib/llm/tests/http-service.rs +++ b/lib/llm/tests/http-service.rs @@ -23,6 +23,7 @@ use dynamo_llm::http::{ service::{ error::HttpError, metrics::{Endpoint, RequestType, Status}, + rate_limiter::RateLimiterConfig, service_v2::HttpService, Metrics, }, @@ -234,6 +235,7 @@ fn compute_index(endpoint: &Endpoint, request_type: &RequestType, status: &Statu let status = match status { Status::Success => 0, Status::Error => 1, + Status::Rejected => 2, }; endpoint * 4 + request_type * 2 + status @@ -1259,3 +1261,626 @@ async fn test_request_id_annotation() { cancel_token.cancel(); task.await.unwrap().unwrap(); } + +// === Rate Limiting Tests === + +/// Engine that simulates low TTFT to trigger rate limiting +struct SlowTTFTEngine { + ttft_delay_ms: u64, +} + +#[async_trait] +impl + AsyncEngine< + SingleIn, + ManyOut>, + Error, + > for SlowTTFTEngine +{ + async fn generate( + &self, + request: SingleIn, + ) -> 
Result>, Error> { + let (request, context) = request.transfer(()); + let ctx = context.context(); + + let generator = request.response_generator(); + let ttft_delay_ms = self.ttft_delay_ms; + + let stream = stream! { + // Simulate slow TTFT + tokio::time::sleep(std::time::Duration::from_millis(ttft_delay_ms)).await; + + // Generate a few tokens with normal ITL + for i in 0..3 { + let inner = generator.create_choice(i, Some(format!("token {i}")), None, None); + let output = NvCreateChatCompletionStreamResponse { inner }; + yield Annotated::from_data(output); + + if i < 2 { + tokio::time::sleep(std::time::Duration::from_millis(10)).await; // Normal ITL + } + } + }; + + Ok(ResponseStream::new(Box::pin(stream), ctx)) + } +} + +/// Engine that simulates slow ITL to trigger rate limiting +struct SlowITLEngine { + itl_delay_ms: u64, +} + +#[async_trait] +impl + AsyncEngine< + SingleIn, + ManyOut>, + Error, + > for SlowITLEngine +{ + async fn generate( + &self, + request: SingleIn, + ) -> Result>, Error> { + let (request, context) = request.transfer(()); + let ctx = context.context(); + + let generator = request.response_generator(); + let itl_delay_ms = self.itl_delay_ms; + + let stream = stream! 
{ + // Fast TTFT + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + + // Generate tokens with slow ITL + for i in 0..5 { + let inner = generator.create_choice(i, Some(format!("token {i}")), None, None); + let output = NvCreateChatCompletionStreamResponse { inner }; + yield Annotated::from_data(output); + + if i < 4 { + tokio::time::sleep(std::time::Duration::from_millis(itl_delay_ms)).await; // Slow ITL + } + } + }; + + Ok(ResponseStream::new(Box::pin(stream), ctx)) + } +} + +#[tokio::test] +async fn test_rate_limiting_triggers_correctly() { + // Create rate limiter config with low thresholds for testing + let rate_limiter_config = RateLimiterConfig::new( + 1.0, // TTFT threshold: 1 second + 0.1, // ITL threshold: 100ms + 5.0, // Time constant: 5 seconds + false, // Global rate limiting + ) + .unwrap(); + + let service = HttpService::builder() + .port(8996) + .rate_limiter_config(rate_limiter_config) + .build() + .unwrap(); + + let state = service.state_clone(); + let manager = state.manager(); + + let token = CancellationToken::new(); + let cancel_token = token.clone(); + let task = tokio::spawn(async move { service.run(token.clone()).await }); + + // Add engines with different performance characteristics + let fast_engine = Arc::new(CounterEngine {}); + let slow_ttft_engine = Arc::new(SlowTTFTEngine { + ttft_delay_ms: 1500, + }); // 1.5s TTFT + let slow_itl_engine = Arc::new(SlowITLEngine { itl_delay_ms: 200 }); // 200ms ITL + + manager + .add_chat_completions_model("fast", fast_engine) + .unwrap(); + manager + .add_chat_completions_model("slow_ttft", slow_ttft_engine) + .unwrap(); + manager + .add_chat_completions_model("slow_itl", slow_itl_engine) + .unwrap(); + + let client = reqwest::Client::new(); + let metrics = state.metrics_clone(); + + // Wait for service to be ready + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + // Test 1: Fast model should work fine initially + let request = 
async_openai::types::CreateChatCompletionRequestArgs::default() + .model("fast") + .messages(vec![ + async_openai::types::ChatCompletionRequestMessage::User( + async_openai::types::ChatCompletionRequestUserMessage { + content: async_openai::types::ChatCompletionRequestUserMessageContent::Text( + "test".to_string(), + ), + name: None, + }, + ), + ]) + .stream(true) + .max_tokens(10_u32) + .build() + .unwrap(); + + let response = client + .post("http://localhost:8996/v1/chat/completions") + .json(&request) + .send() + .await + .unwrap(); + + assert!( + response.status().is_success(), + "Fast model should work initially" + ); + let _ = response.bytes().await.unwrap(); + + // Test 2: Slow TTFT model should trigger rate limiting after a few requests + let mut slow_ttft_request = request.clone(); + slow_ttft_request.model = "slow_ttft".to_string(); + + // Make several requests to build up the EMA + for i in 0..3 { + println!("Sending slow TTFT request {}", i + 1); + let response = client + .post("http://localhost:8996/v1/chat/completions") + .json(&slow_ttft_request) + .send() + .await + .unwrap(); + + // First few requests should succeed (building up EMA) + if i < 2 { + assert!( + response.status().is_success(), + "Slow TTFT request {} should succeed while building EMA", + i + 1 + ); + let _ = response.bytes().await.unwrap(); + } else { + // Later requests should be rate limited + if response.status() == StatusCode::TOO_MANY_REQUESTS { + println!("Rate limiting triggered after {} requests", i + 1); + break; + } else { + let _ = response.bytes().await.unwrap(); + } + } + + tokio::time::sleep(std::time::Duration::from_millis(100)).await; // Small delay between requests + } + + // Test 3: Slow ITL model should also trigger rate limiting + let mut slow_itl_request = request.clone(); + slow_itl_request.model = "slow_itl".to_string(); + + for i in 0..3 { + println!("Sending slow ITL request {}", i + 1); + let response = client + 
.post("http://localhost:8996/v1/chat/completions") + .json(&slow_itl_request) + .send() + .await + .unwrap(); + + if i < 2 { + assert!( + response.status().is_success(), + "Slow ITL request {} should succeed while building EMA", + i + 1 + ); + let _ = response.bytes().await.unwrap(); + } else { + // Later requests should be rate limited + if response.status() == StatusCode::TOO_MANY_REQUESTS { + println!("ITL rate limiting triggered after {} requests", i + 1); + break; + } else { + let _ = response.bytes().await.unwrap(); + } + } + + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + + // Test 4: Verify rejection metrics were recorded + let rejection_count = metrics.get_rate_limit_requests_counter( + "slow_ttft", + &Endpoint::ChatCompletions, + &RequestType::Stream, + &Status::Rejected, + ) + metrics.get_rate_limit_requests_counter( + "slow_itl", + &Endpoint::ChatCompletions, + &RequestType::Stream, + &Status::Rejected, + ); + + println!("Total rejection count: {}", rejection_count); + + cancel_token.cancel(); + task.await.unwrap().unwrap(); +} + +#[tokio::test] +async fn test_rate_limiting_http_integration() { + let rate_limiter_config = RateLimiterConfig::new(0.1, 0.01, 5.0, false).unwrap(); + let service = HttpService::builder() + .port(8997) + .rate_limiter_config(rate_limiter_config) + .build() + .unwrap(); + + let state = service.state_clone(); + let manager = state.manager(); + + let token = CancellationToken::new(); + let cancel_token = token.clone(); + let task = tokio::spawn(async move { service.run(token.clone()).await }); + + // Use simple CounterEngine (already exists) + let engine = Arc::new(CounterEngine {}); + manager.add_chat_completions_model("test", engine).unwrap(); + + // Manually record high TTFT values to trigger rate limiting + state.rate_limiter().record_ttft("test", 0.5); + state.rate_limiter().record_ttft("test", 0.3); + state.rate_limiter().record_ttft("test", 0.4); + + let client = reqwest::Client::new(); + 
tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let request = async_openai::types::CreateChatCompletionRequestArgs::default() + .model("test") + .messages(vec![ + async_openai::types::ChatCompletionRequestMessage::User( + async_openai::types::ChatCompletionRequestUserMessage { + content: async_openai::types::ChatCompletionRequestUserMessageContent::Text( + "test".to_string(), + ), + name: None, + }, + ), + ]) + .stream(true) + .max_tokens(3_u32) + .build() + .unwrap(); + + // This request should be rate limited + let response = client + .post("http://localhost:8997/v1/chat/completions") + .json(&request) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS); + println!("✅ Rate limiting triggered correctly!"); + + // Verify metrics were recorded + let rejection_count = state.metrics_clone().get_rate_limit_requests_counter( + "test", + &Endpoint::ChatCompletions, + &RequestType::Stream, + &Status::Rejected, + ); + assert!(rejection_count > 0, "Should have recorded rejection"); + + cancel_token.cancel(); + task.await.unwrap().unwrap(); +} + +#[tokio::test] +async fn test_per_model_vs_global_rate_limiting() { + // Test global rate limiting (per_model_limits = false) + let global_config = RateLimiterConfig::new(0.8, 0.08, 3.0, false).unwrap(); + let service1 = HttpService::builder() + .port(8998) + .rate_limiter_config(global_config) + .build() + .unwrap(); + + let state1 = service1.state_clone(); + let manager1 = state1.manager(); + + let token1 = CancellationToken::new(); + let cancel_token1 = token1.clone(); + let task1 = tokio::spawn(async move { service1.run(token1.clone()).await }); + + // Test per-model rate limiting (per_model_limits = true) + let per_model_config = RateLimiterConfig::new(0.8, 0.08, 3.0, true).unwrap(); + let service2 = HttpService::builder() + .port(8993) + .rate_limiter_config(per_model_config) + .build() + .unwrap(); + + let state2 = service2.state_clone(); + let manager2 = 
state2.manager(); + + let token2 = CancellationToken::new(); + let cancel_token2 = token2.clone(); + let task2 = tokio::spawn(async move { service2.run(token2.clone()).await }); + + // Add slow engines to both services + let slow_engine1a = Arc::new(SlowTTFTEngine { + ttft_delay_ms: 1200, + }); + let slow_engine1b = Arc::new(SlowTTFTEngine { + ttft_delay_ms: 1200, + }); + let slow_engine2a = Arc::new(SlowTTFTEngine { + ttft_delay_ms: 1200, + }); + let slow_engine2b = Arc::new(SlowTTFTEngine { + ttft_delay_ms: 1200, + }); + + manager1 + .add_chat_completions_model("model_a", slow_engine1a) + .unwrap(); + manager1 + .add_chat_completions_model("model_b", slow_engine1b) + .unwrap(); + manager2 + .add_chat_completions_model("model_a", slow_engine2a) + .unwrap(); + manager2 + .add_chat_completions_model("model_b", slow_engine2b) + .unwrap(); + + let client = reqwest::Client::new(); + + tokio::time::sleep(std::time::Duration::from_millis(200)).await; + + let base_request = async_openai::types::CreateChatCompletionRequestArgs::default() + .messages(vec![ + async_openai::types::ChatCompletionRequestMessage::User( + async_openai::types::ChatCompletionRequestUserMessage { + content: async_openai::types::ChatCompletionRequestUserMessageContent::Text( + "test".to_string(), + ), + name: None, + }, + ), + ]) + .stream(true) + .max_tokens(10_u32) + .build() + .unwrap(); + + // Test global rate limiting - model_a affects model_b + println!("Testing global rate limiting..."); + for _i in 0..3 { + let mut request_a = base_request.clone(); + request_a.model = "model_a".to_string(); + + let response = client + .post("http://localhost:8998/v1/chat/completions") + .json(&request_a) + .send() + .await + .unwrap(); + + if response.status().is_success() { + let _ = response.bytes().await.unwrap(); + } + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + + // Now model_b should be affected by model_a's rate limiting (global) + let mut request_b = base_request.clone(); + 
request_b.model = "model_b".to_string(); + + let response = client + .post("http://localhost:8998/v1/chat/completions") + .json(&request_b) + .send() + .await + .unwrap(); + + println!( + "Global rate limiting - model_b status: {}", + response.status() + ); + + // Test per-model rate limiting - model_a doesn't affect model_b + println!("Testing per-model rate limiting..."); + for _i in 0..3 { + let mut request_a = base_request.clone(); + request_a.model = "model_a".to_string(); + + let response = client + .post("http://localhost:8993/v1/chat/completions") + .json(&request_a) + .send() + .await + .unwrap(); + + if response.status().is_success() { + let _ = response.bytes().await.unwrap(); + } + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + + // model_b should NOT be affected by model_a's rate limiting (per-model) + let mut request_b2 = base_request.clone(); + request_b2.model = "model_b".to_string(); + + let response = client + .post("http://localhost:8998/v1/chat/completions") + .json(&request_b2) + .send() + .await + .unwrap(); + + println!( + "Per-model rate limiting - model_b status: {}", + response.status() + ); + // Model B should succeed since it has its own rate limiting state + assert!( + response.status().is_success(), + "Per-model rate limiting should not affect model_b" + ); + + if response.status().is_success() { + let _ = response.bytes().await.unwrap(); + } + + cancel_token1.cancel(); + cancel_token2.cancel(); + task1.await.unwrap().unwrap(); + task2.await.unwrap().unwrap(); +} + +#[tokio::test] +async fn test_rate_limiting_recovery() { + let rate_limiter_config = RateLimiterConfig::new( + 0.6, // TTFT threshold: 600ms + 0.06, // ITL threshold: 60ms + 1.0, // Short time constant for faster recovery + false, + ) + .unwrap(); + + let service = HttpService::builder() + .port(8999) + .rate_limiter_config(rate_limiter_config) + .build() + .unwrap(); + + let state = service.state_clone(); + let manager = state.manager(); + + let 
token = CancellationToken::new(); + let cancel_token = token.clone(); + let task = tokio::spawn(async move { service.run(token.clone()).await }); + + // Add engines with different speeds + let slow_engine = Arc::new(SlowTTFTEngine { + ttft_delay_ms: 1000, + }); // 1s TTFT + let fast_engine = Arc::new(CounterEngine {}); + + manager + .add_chat_completions_model("slow", slow_engine) + .unwrap(); + manager + .add_chat_completions_model("fast", fast_engine) + .unwrap(); + + let client = reqwest::Client::new(); + + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let slow_request = async_openai::types::CreateChatCompletionRequestArgs::default() + .model("slow") + .messages(vec![ + async_openai::types::ChatCompletionRequestMessage::User( + async_openai::types::ChatCompletionRequestUserMessage { + content: async_openai::types::ChatCompletionRequestUserMessageContent::Text( + "test".to_string(), + ), + name: None, + }, + ), + ]) + .stream(true) + .max_tokens(10_u32) + .build() + .unwrap(); + + let fast_request = async_openai::types::CreateChatCompletionRequestArgs::default() + .model("fast") + .messages(vec![ + async_openai::types::ChatCompletionRequestMessage::User( + async_openai::types::ChatCompletionRequestUserMessage { + content: async_openai::types::ChatCompletionRequestUserMessageContent::Text( + "test".to_string(), + ), + name: None, + }, + ), + ]) + .stream(true) + .max_tokens(10_u32) + .build() + .unwrap(); + + // Phase 1: Trigger rate limiting with slow requests + println!("Phase 1: Triggering rate limiting..."); + for i in 0..4 { + let response = client + .post("http://localhost:8999/v1/chat/completions") + .json(&slow_request) + .send() + .await + .unwrap(); + + if response.status() == StatusCode::TOO_MANY_REQUESTS { + println!("Rate limiting triggered at request {}", i + 1); + break; + } else { + let _ = response.bytes().await.unwrap(); + } + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + + // Phase 2: Wait for 
system to recover (time constant is 1.0s) + println!("Phase 2: Waiting for recovery..."); + tokio::time::sleep(std::time::Duration::from_millis(3000)).await; // Wait 3 time constants + + // Phase 3: Send fast requests to bring down the EMA + println!("Phase 3: Sending fast requests to improve EMA..."); + for i in 0..3 { + let response = client + .post("http://localhost:8999/v1/chat/completions") + .json(&fast_request) + .send() + .await + .unwrap(); + + println!("Fast request {} status: {}", i + 1, response.status()); + if response.status().is_success() { + let _ = response.bytes().await.unwrap(); + } + tokio::time::sleep(std::time::Duration::from_millis(200)).await; + } + + // Phase 4: Verify that slow requests work again (system recovered) + println!("Phase 4: Testing recovery with moderate request..."); + let response = client + .post("http://localhost:8999/v1/chat/completions") + .json(&fast_request) + .send() + .await + .unwrap(); + + println!("Recovery test status: {}", response.status()); + assert!( + response.status().is_success(), + "System should have recovered and accept requests again" + ); + + if response.status().is_success() { + let _ = response.bytes().await.unwrap(); + } + + cancel_token.cancel(); + task.await.unwrap().unwrap(); +}