From 9b75bc0b06a315e1039f0d991e7494499926312c Mon Sep 17 00:00:00 2001 From: kerthcet Date: Sun, 4 Jan 2026 11:37:06 +0800 Subject: [PATCH] make client immutate Signed-off-by: kerthcet --- README.md | 2 +- examples/wrr.rs | 2 +- src/client/client.rs | 4 ++-- src/router/random.rs | 2 +- src/router/router.rs | 2 +- src/router/wrr.rs | 25 +++++++++++--------- tests/client.rs | 56 ++++++++++++++++++++++---------------------- 7 files changed, 48 insertions(+), 45 deletions(-) diff --git a/README.md b/README.md index ad36908..5af988f 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,7 @@ fn main() { .build() .unwrap(); - let mut client = client::Client::new(config); + let client = client::Client::new(config); let request = chat::CreateChatCompletionRequestArgs::default() .messages([ chat::ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(), diff --git a/examples/wrr.rs b/examples/wrr.rs index 63d5767..3e7bfa7 100644 --- a/examples/wrr.rs +++ b/examples/wrr.rs @@ -23,7 +23,7 @@ fn main() { .build() .unwrap(); - let mut client = client::Client::new(config); + let client = client::Client::new(config); let request = chat::CreateChatCompletionRequestArgs::default() .messages([ chat::ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(), diff --git a/src/client/client.rs b/src/client/client.rs index f465f1e..f00dcb9 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -29,7 +29,7 @@ impl Client { } pub async fn create_response( - &mut self, + &self, request: responses::CreateResponse, ) -> Result { let candidate = self.router.sample(); @@ -39,7 +39,7 @@ impl Client { // This is chat completion endpoint. pub async fn create_completion( - &mut self, + &self, request: chat::CreateChatCompletionRequest, ) -> Result { let candidate = self.router.sample(); diff --git a/src/router/random.rs b/src/router/random.rs index 1ca2221..fc131c4 100644 --- a/src/router/random.rs +++ b/src/router/random.rs @@ -18,7 +18,7 @@ impl Router for RandomRouter { "RandomRouter" } - fn sample(&mut self) -> ModelName { + fn sample(&self) -> ModelName { let mut rng = rand::rng(); let idx = rng.random_range(0..self.model_infos.len()); self.model_infos[idx].name.clone() diff --git a/src/router/router.rs b/src/router/router.rs index 783db42..68a500c 100644 --- a/src/router/router.rs +++ b/src/router/router.rs @@ -24,7 +24,7 @@ pub fn construct_router(mode: RouterMode, models: Vec) -> Box &'static str; - fn sample(&mut self) -> ModelName; + fn sample(&self) -> ModelName; } #[cfg(test)] diff --git a/src/router/wrr.rs b/src/router/wrr.rs index cf9c416..b96d86d 100644 --- a/src/router/wrr.rs +++ b/src/router/wrr.rs @@ -1,3 +1,5 @@ +use std::sync::atomic::AtomicI32; + use crate::client::config::ModelName; use crate::router::router::{ModelInfo, Router}; @@ -5,7 +7,7 @@ pub struct WeightedRoundRobinRouter { total_weight: i32, model_infos: Vec, // current_weight is ordered by model_infos index. - current_weights: Vec, + current_weights: Vec, } impl WeightedRoundRobinRouter { @@ -16,7 +18,7 @@ impl WeightedRoundRobinRouter { Self { model_infos: model_infos, total_weight: total_weight, - current_weights: vec![0; length], + current_weights: (0..length).map(|_| AtomicI32::new(0)).collect(), } } } @@ -27,27 +29,28 @@ impl Router for WeightedRoundRobinRouter { } // Use Smooth Weighted Round Robin Algorithm. - fn sample(&mut self) -> ModelName { + fn sample(&self) -> ModelName { // return early if only one model. if self.model_infos.len() == 1 { return self.model_infos[0].name.clone(); } - self.current_weights - .iter_mut() - .enumerate() - .for_each(|(i, weight)| { - *weight += self.model_infos[i].weight; - }); + // 1. add weight to current weight. + self.model_infos.iter().enumerate().for_each(|(i, weight)| { + self.current_weights[i].fetch_add(weight.weight, std::sync::atomic::Ordering::Relaxed); + }); let mut max_index = 0; for i in 1..self.current_weights.len() { - if self.current_weights[i] > self.current_weights[max_index] { + if self.current_weights[i].load(std::sync::atomic::Ordering::Relaxed) + > self.current_weights[max_index].load(std::sync::atomic::Ordering::Relaxed) + { max_index = i; } } - self.current_weights[max_index] -= self.total_weight; + self.current_weights[max_index] + .fetch_sub(self.total_weight, std::sync::atomic::Ordering::Relaxed); self.model_infos[max_index].name.clone() } } diff --git a/tests/client.rs b/tests/client.rs index f30d1b5..c46f151 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -8,6 +8,31 @@ use arms::types::responses; mod tests { use super::*; + #[tokio::test] + async fn test_completion() { + from_filename(".env.integration-test").ok(); + + let config = client::Config::builder() + .provider("faker") + .model( + client::ModelConfig::builder() + .name("fake-completion-model") + .build() + .unwrap(), + ) + .build() + .unwrap(); + + let client = client::Client::new(config); + let request = chat::CreateChatCompletionRequestArgs::default() + .build() + .unwrap(); + + let response = client.create_completion(request).await.unwrap(); + assert!(response.id.starts_with("fake-completion-id")); + assert!(response.model == "fake-completion-model"); + } + #[tokio::test] async fn test_response() { from_filename(".env.integration-test").ok(); @@ -24,7 +49,7 @@ mod tests { .build() .unwrap(); - let mut client = client::Client::new(config); + let client = client::Client::new(config); let request = responses::CreateResponseArgs::default() .input("tell me the weather today") .build() @@ -45,7 +70,7 @@ mod tests { ) .build() .unwrap(); - let mut client = client::Client::new(config); + let client = client::Client::new(config); let request = responses::CreateResponseArgs::default() .model("gpt-3.5-turbo") .input("tell me a joke") @@ -74,36 +99,11 @@ mod tests { ) .build() .unwrap(); - let mut client = client::Client::new(config); + let client = client::Client::new(config); let request = responses::CreateResponseArgs::default() .input("give me a poem about nature") .build() .unwrap(); let _ = client.create_response(request).await.unwrap(); } - - #[tokio::test] - async fn test_completion() { - from_filename(".env.integration-test").ok(); - - let config = client::Config::builder() - .provider("faker") - .model( - client::ModelConfig::builder() - .name("fake-completion-model") - .build() - .unwrap(), - ) - .build() - .unwrap(); - - let mut client = client::Client::new(config); - let request = chat::CreateChatCompletionRequestArgs::default() - .build() - .unwrap(); - - let response = client.create_completion(request).await.unwrap(); - assert!(response.id.starts_with("fake-completion-id")); - assert!(response.model == "fake-completion-model"); - } }