From 2fbe2b6ed9e112c2c4f4ec41c04d7c29bc74ab8b Mon Sep 17 00:00:00 2001 From: clearloop Date: Thu, 18 Dec 2025 01:28:18 +0800 Subject: [PATCH] feat(ullm): introduce tool calls --- Cargo.lock | 14 +- Cargo.toml | 7 +- crates/cli/Cargo.toml | 22 +++ crates/{ullm/src => cli}/bin/ullm.rs | 2 +- crates/{ullm/src/cmd => cli/src}/chat.rs | 6 +- crates/{ullm/src/cmd => cli/src}/config.rs | 6 +- .../{ullm/src/cmd/mod.rs => cli/src/lib.rs} | 0 crates/core/Cargo.toml | 6 +- crates/core/src/agent.rs | 52 ++++++ crates/core/src/chat.rs | 129 +++++++++++++- crates/core/src/config.rs | 97 +++-------- crates/core/src/lib.rs | 6 +- crates/core/src/message.rs | 9 + crates/core/src/provider.rs | 15 +- crates/core/src/response.rs | 14 ++ crates/core/src/stream.rs | 16 +- crates/core/src/template.rs | 37 ---- crates/core/src/tool.rs | 10 +- crates/ullm/Cargo.toml | 11 -- crates/ullm/src/lib.rs | 6 +- llm/deepseek/src/llm.rs | 26 +-- llm/deepseek/src/request.rs | 160 ++++++++---------- 22 files changed, 372 insertions(+), 279 deletions(-) create mode 100644 crates/cli/Cargo.toml rename crates/{ullm/src => cli}/bin/ullm.rs (88%) rename crates/{ullm/src/cmd => cli/src}/chat.rs (96%) rename crates/{ullm/src/cmd => cli/src}/config.rs (90%) rename crates/{ullm/src/cmd/mod.rs => cli/src/lib.rs} (100%) create mode 100644 crates/core/src/agent.rs delete mode 100644 crates/core/src/template.rs diff --git a/Cargo.lock b/Cargo.lock index 65e7513..8a72edd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2950,18 +2950,24 @@ dependencies = [ [[package]] name = "ullm" version = "0.0.9" +dependencies = [ + "ullm-core", + "ullm-deepseek", +] + +[[package]] +name = "ullm-cli" +version = "0.0.9" dependencies = [ "anyhow", "clap", "dirs", "futures-util", "serde", - "tokio", "toml", "tracing", "tracing-subscriber", - "ullm-core", - "ullm-deepseek", + "ullm", ] [[package]] @@ -2969,8 +2975,10 @@ name = "ullm-core" version = "0.0.9" dependencies = [ "anyhow", + "async-stream", "derive_more", "futures-core", + "futures-util", "reqwest", "schemars", "serde", diff --git a/Cargo.toml b/Cargo.toml index 6e8efad..c7101e2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,11 +12,12 @@ documentation = "https://cydonia.docs.rs" keywords = ["llm", "agent", "ai"] [workspace.dependencies] -model = { path = "legacy/model", package = "cydonia-model" } candle = { path = "crates/candle", package = "cydonia-candle" } -ucore = { path = "crates/core", package = "ullm-core" } deepseek = { path = "llm/deepseek", package = "ullm-deepseek" } - +model = { path = "legacy/model", package = "cydonia-model" } +ullm = { path = "crates/ullm" } +ucore = { path = "crates/core", package = "ullm-core" } +ucli = { path = "crates/cli", package = "ullm-cli" } # crates.io anyhow = "1" diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml new file mode 100644 index 0000000..cf55045 --- /dev/null +++ b/crates/cli/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "ullm-cli" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +documentation.workspace = true +keywords.workspace = true + +[dependencies] +ullm.workspace = true + +# crates-io dependencies +anyhow.workspace = true +clap.workspace = true +dirs.workspace = true +futures-util.workspace = true +serde.workspace = true +toml.workspace = true +tracing.workspace = true +tracing-subscriber.workspace = true diff --git a/crates/ullm/src/bin/ullm.rs b/crates/cli/bin/ullm.rs similarity index 88% rename from crates/ullm/src/bin/ullm.rs rename to 
crates/cli/bin/ullm.rs
index 6aef0fa..a0f06e2 100644
--- a/crates/ullm/src/bin/ullm.rs
+++ b/crates/cli/bin/ullm.rs
@@ -1,6 +1,6 @@
 use anyhow::Result;
 use clap::Parser;
-use ullm::cmd::{App, Command, Config};
+use ucli::{App, Command, Config};
 
 #[tokio::main]
 async fn main() -> Result<()> {
diff --git a/crates/ullm/src/cmd/chat.rs b/crates/cli/src/chat.rs
similarity index 96%
rename from crates/ullm/src/cmd/chat.rs
rename to crates/cli/src/chat.rs
index 2f27c3d..df02820 100644
--- a/crates/ullm/src/cmd/chat.rs
+++ b/crates/cli/src/chat.rs
@@ -1,7 +1,6 @@
 //! Chat command
 
 use super::Config;
-use crate::DeepSeek;
 use anyhow::Result;
 use clap::{Args, ValueEnum};
 use futures_util::StreamExt;
@@ -9,7 +8,8 @@ use std::{
     fmt::{Display, Formatter},
     io::{BufRead, Write},
 };
-use ucore::{Chat, Client, LLM, Message};
+use ullm::DeepSeek;
+use ullm::{Chat, Client, LLM, Message};
 
 /// Chat command arguments
 #[derive(Debug, Args)]
@@ -64,7 +64,7 @@ impl ChatCmd {
         Ok(())
     }
 
-    async fn send(chat: &mut Chat<DeepSeek>, message: Message, stream: bool) -> Result<()> {
+    async fn send(chat: &mut Chat<DeepSeek>, message: Message, stream: bool) -> Result<()> {
         if stream {
             let mut response_content = String::new();
             {
diff --git a/crates/ullm/src/cmd/config.rs b/crates/cli/src/config.rs
similarity index 90%
rename from crates/ullm/src/cmd/config.rs
rename to crates/cli/src/config.rs
index 5935d52..6916820 100644
--- a/crates/ullm/src/cmd/config.rs
+++ b/crates/cli/src/config.rs
@@ -10,7 +10,7 @@ static CONFIG: LazyLock =
 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct Config {
     /// The configuration for the CLI
-    config: ucore::Config,
+    config: ullm::General,
 
     /// The API keys for LLMs
     pub key: BTreeMap<String, String>,
@@ -31,7 +31,7 @@ impl Config {
     }
 
     /// Get the core config
-    pub fn config(&self) -> &ucore::Config {
+    pub fn config(&self) -> &ullm::General {
         &self.config
     }
 }
@@ -39,7 +39,7 @@ impl Default for Config {
     fn default() -> Self {
         Self {
-            config: ucore::Config::default(),
+            config: ullm::General::default(),
             key: [("deepseek".to_string(), "YOUR_API_KEY".to_string())]
                 .into_iter()
                 .collect::<_>(),
diff --git a/crates/ullm/src/cmd/mod.rs b/crates/cli/src/lib.rs
similarity index 100%
rename from crates/ullm/src/cmd/mod.rs
rename to crates/cli/src/lib.rs
diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml
index e10ef4b..52fb1e1 100644
--- a/crates/core/Cargo.toml
+++ b/crates/core/Cargo.toml
@@ -10,11 +10,11 @@ keywords.workspace = true
 
 [dependencies]
 anyhow.workspace = true
+async-stream.workspace = true
 derive_more.workspace = true
 serde.workspace = true
+serde_json.workspace = true
 futures-core.workspace = true
+futures-util.workspace = true
 reqwest.workspace = true
 schemars.workspace = true
-
-[dev-dependencies]
-serde_json.workspace = true
diff --git a/crates/core/src/agent.rs b/crates/core/src/agent.rs
new file mode 100644
index 0000000..e7d52e7
--- /dev/null
+++ b/crates/core/src/agent.rs
@@ -0,0 +1,52 @@
+//! Turbofish Agent library
+
+use crate::{Message, StreamChunk, Tool, ToolCall, ToolChoice, message::ToolMessage};
+use anyhow::Result;
+
+/// A trait for turbofish agents
+///
+/// TODO: add schemar for request and response
+pub trait Agent: Clone {
+    /// The parsed chunk from [StreamChunk]
+    type Chunk;
+
+    /// The system prompt for the agent
+    const SYSTEM_PROMPT: &str;
+
+    /// The tools for the agent
+    const TOOLS: Vec<Tool> = Vec::new();
+
+    /// Filter the messages to match required tools for the agent
+    fn filter(&self, _message: &str) -> ToolChoice {
+        ToolChoice::Auto
+    }
+
+    /// Dispatch tool calls
+    fn dispatch(&self, tools: &[ToolCall]) -> impl Future<Output = Vec<ToolMessage>> {
+        async move {
+            tools
+                .iter()
+                .map(|tool| ToolMessage {
+                    tool: tool.id.clone(),
+                    message: Message::tool(format!(
+                        "function {} not available",
+                        tool.function.name
+                    )),
+                })
+                .collect()
+        }
+    }
+
+    /// Parse a chunk from [StreamChunk]
+    fn chunk(&self, chunk: &StreamChunk) -> impl Future<Output = Result<Self::Chunk>>;
+}
+
+impl Agent for () {
+    type Chunk = StreamChunk;
+
+    const SYSTEM_PROMPT: &str = "You are a helpful assistant.";
+
+    async fn chunk(&self, chunk: &StreamChunk) -> Result<Self::Chunk> {
+        Ok(chunk.clone())
+    }
+}
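For orientation, a minimal sketch of a custom agent against the new trait. The `Clock` agent, its `now` tool name, and the canned reply are hypothetical; `TOOLS` is left at its empty default (non-empty `Vec<Tool>` values cannot be built in a `const` here); and the sketch assumes `ToolMessage` is reachable from the crate root (inside `ucore` it lives at `crate::message::ToolMessage`, which this patch does not re-export):

    use anyhow::Result;
    use ucore::{Agent, Message, StreamChunk, ToolCall, ToolChoice};
    // Assumed re-exported; inside `ucore` this is `crate::message::ToolMessage`.
    use ucore::message::ToolMessage;

    #[derive(Clone)]
    struct Clock;

    impl Agent for Clock {
        // Pass raw chunks through; a real agent could parse them instead.
        type Chunk = StreamChunk;

        const SYSTEM_PROMPT: &str = "You are a clock. Answer questions about time.";

        // Force the (hypothetical) `now` function for every message.
        fn filter(&self, _message: &str) -> ToolChoice {
            ToolChoice::from("now")
        }

        // Answer every tool call with a canned tool message.
        async fn dispatch(&self, tools: &[ToolCall]) -> Vec<ToolMessage> {
            tools
                .iter()
                .map(|call| ToolMessage {
                    tool: call.id.clone(),
                    message: Message::tool("12:00"),
                })
                .collect()
        }

        async fn chunk(&self, chunk: &StreamChunk) -> Result<Self::Chunk> {
            Ok(chunk.clone())
        }
    }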
diff --git a/crates/core/src/chat.rs b/crates/core/src/chat.rs
index 4b745fa..1f1e4d7 100644
--- a/crates/core/src/chat.rs
+++ b/crates/core/src/chat.rs
@@ -1,15 +1,19 @@
 //! Chat abstractions for the unified LLM Interfaces
 
 use crate::{
-    LLM, Response, Role, StreamChunk,
+    Agent, Config, FinishReason, General, LLM, Response, Role,
     message::{AssistantMessage, Message, ToolMessage},
 };
 use anyhow::Result;
 use futures_core::Stream;
+use futures_util::StreamExt;
 use serde::Serialize;
 
+const MAX_TOOL_CALLS: usize = 16;
+
 /// A chat for the LLM
-pub struct Chat<P: LLM> {
+#[derive(Clone)]
+pub struct Chat<P: LLM, A: Agent = ()> {
     /// The chat configuration
     pub config: P::ChatConfig,
 
@@ -17,20 +21,125 @@ pub struct Chat<P: LLM> {
     pub messages: Vec<ChatMessage>,
 
     /// The LLM provider
-    pub provider: P,
+    provider: P,
+
+    /// The agent
+    agent: A,
+
+    /// Whether to return the usage information in stream mode
+    usage: bool,
+}
+
+impl<P: LLM> Chat<P> {
+    /// Create a new chat
+    pub fn new(config: General, provider: P) -> Self {
+        Self {
+            messages: vec![],
+            provider,
+            usage: config.usage,
+            agent: (),
+            config: config.into(),
+        }
+    }
 }
 
-impl<P: LLM> Chat<P> {
+impl<P: LLM, A: Agent> Chat<P, A> {
+    /// Add the system prompt to the chat
+    pub fn system<B: Agent>(mut self, agent: B) -> Chat<P, B> {
+        let mut messages = self.messages;
+        if messages.is_empty() {
+            messages.push(Message::system(B::SYSTEM_PROMPT).into());
+        } else if let Some(ChatMessage::System(_)) = messages.first() {
+            messages[0] = Message::system(B::SYSTEM_PROMPT).into();
+        } else {
+            messages = vec![Message::system(B::SYSTEM_PROMPT).into()]
+                .into_iter()
+                .chain(messages)
+                .collect();
+        }
+
+        self.config = self.config.with_tools(B::TOOLS);
+        Chat {
+            messages,
+            provider: self.provider,
+            usage: self.usage,
+            agent,
+            config: self.config,
+        }
+    }
+
     /// Send a message to the LLM
     pub async fn send(&mut self, message: Message) -> Result<Response> {
+        let config = self
+            .config
+            .with_tool_choice(self.agent.filter(message.content.as_str()));
         self.messages.push(message.into());
-        self.provider.send(&self.config, &self.messages).await
+
+        for _ in 0..MAX_TOOL_CALLS {
+            let response = self.provider.send(&config, &self.messages).await?;
+            let Some(tool_calls) = response.tool_calls() else {
+                return Ok(response);
+            };
+
+            let result = self.agent.dispatch(tool_calls).await;
+            self.messages.extend(result.into_iter().map(Into::into));
+        }
+
+        anyhow::bail!("max tool calls reached");
     }
 
     /// Send a message to the LLM with streaming
-    pub fn stream(&mut self, message: Message) -> impl Stream<Item = Result<StreamChunk>> {
+    pub fn stream(
+        &mut self,
+        message: Message,
+    ) -> impl Stream<Item = Result<A::Chunk>> + use<'_, P, A> {
+        let config = self
+            .config
+            .with_tool_choice(self.agent.filter(message.content.as_str()));
         self.messages.push(message.into());
-        self.provider.stream(&self.config, &self.messages)
+
+        async_stream::try_stream! {
+            for _ in 0..MAX_TOOL_CALLS {
+                let messages = self.messages.clone();
+                let inner = self.provider.stream(config.clone(), &messages, self.usage);
+                futures_util::pin_mut!(inner);
+
+                let mut tool_calls = None;
+                let mut message = String::new();
+                while let Some(chunk) = inner.next().await {
+                    let chunk = chunk?;
+                    if let Some(calls) = chunk.tool_calls() {
+                        tool_calls = Some(calls.to_vec());
+                    }
+
+                    if let Some(content) = chunk.content() {
+                        message.push_str(content);
+                    }
+
+                    yield self.agent.chunk(&chunk).await?;
+                    if let Some(reason) = chunk.reason() {
+                        match reason {
+                            FinishReason::Stop => return,
+                            FinishReason::ToolCalls => break,
+                            reason => Err(anyhow::anyhow!("unexpected finish reason: {reason:?}"))?,
+                        }
+                    }
+                }
+
+                if !message.is_empty() {
+                    self.messages.push(Message::assistant(&message).into());
+                }
+
+                if let Some(calls) = tool_calls {
+                    let result = self.agent.dispatch(&calls).await;
+                    self.messages.extend(result.into_iter().map(Into::into));
+                } else {
+                    break;
+                }
+            }
+
+            Err(anyhow::anyhow!("max tool calls reached"))?;
+        }
    }
 }
 
@@ -68,3 +177,9 @@ impl From<Message> for ChatMessage {
         }
     }
 }
+
+impl From<ToolMessage> for ChatMessage {
+    fn from(message: ToolMessage) -> Self {
+        ChatMessage::Tool(message)
+    }
+}
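A usage sketch of the reworked `Chat` (hedged: `Message::user` is assumed to exist alongside the `system`/`assistant`/`tool` constructors shown in this patch):

    use anyhow::Result;
    use ullm::{Client, DeepSeek, General, LLM, Message};

    async fn demo() -> Result<()> {
        let llm = DeepSeek::new(Client::new(), "YOUR_API_KEY")?;

        // `chat` seeds a `Chat<DeepSeek>` with the unit agent `()`.
        let mut chat = llm.chat(General::new("deepseek-chat"));

        // `send` now loops internally: any tool calls in the response are
        // dispatched to the agent, and the loop bails after MAX_TOOL_CALLS
        // rounds instead of spinning forever.
        let response = chat.send(Message::user("Hello!")).await?;
        println!("finish reason: {:?}", response.reason());
        Ok(())
    }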
diff --git a/crates/core/src/config.rs b/crates/core/src/config.rs
index 6e43d7f..062c6ef 100644
--- a/crates/core/src/config.rs
+++ b/crates/core/src/config.rs
@@ -3,103 +3,48 @@
 use crate::{Tool, ToolChoice};
 use serde::{Deserialize, Serialize};
 
+/// LLM configuration
+pub trait Config: From<General> + Sized + Clone {
+    /// Create a new configuration with tools
+    fn with_tools(self, tools: Vec<Tool>) -> Self;
+
+    /// Create a new configuration with tool choice
+    ///
+    /// This should be used at the per-message level.
+    fn with_tool_choice(&self, tool_choice: ToolChoice) -> Self;
+}
+
 /// Chat configuration
 #[derive(Debug, Clone, Deserialize, Serialize)]
-pub struct Config {
-    /// The frequency penalty of the model
-    pub frequency: i8,
-
-    /// Whether to response in JSON
-    pub json: bool,
-
-    /// Whether to return the log probabilities
-    pub logprobs: bool,
-
+pub struct General {
     /// The model to use
     pub model: String,
 
-    /// The presence penalty of the model
-    pub presence: i8,
-
-    /// Stop sequences to halt generation
-    pub stop: Vec<String>,
-
-    /// The temperature of the model
-    pub temperature: f32,
-
-    /// Whether to enable thinking
-    pub think: bool,
-
-    /// Controls which tool is called by the model
-    pub tool_choice: ToolChoice,
-
-    /// A list of tools the model may call
-    pub tools: Vec<Tool>,
-
-    /// The top probability of the model
-    pub top_p: f32,
-
-    /// The number of top log probabilities to return
-    pub top_logprobs: usize,
-
-    /// The number of max tokens to generate
-    pub tokens: usize,
+    /// The tools to use
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tools: Option<Vec<Tool>>,
 
     /// Whether to return the usage information in stream mode
     pub usage: bool,
 }
 
-impl Config {
+impl General {
     /// Create a new configuration
     pub fn new(model: impl Into<String>) -> Self {
         Self {
             model: model.into(),
-            ..Default::default()
+            tools: None,
+            usage: false,
         }
     }
-
-    /// Add a tool to the configuration
-    pub fn tool(mut self, tool: Tool) -> Self {
-        self.tools.push(tool);
-        self
-    }
-
-    /// Set tools for the configuration
-    pub fn tools(mut self, tools: Vec<Tool>) -> Self {
-        self.tools = tools;
-        self
-    }
-
-    /// Set the tool choice for the configuration
-    pub fn tool_choice(mut self, choice: ToolChoice) -> Self {
-        self.tool_choice = choice;
-        self
-    }
-
-    /// Set stop sequences for the configuration
-    pub fn stop(mut self, sequences: Vec<String>) -> Self {
-        self.stop = sequences;
-        self
-    }
 }
 
-impl Default for Config {
+impl Default for General {
     fn default() -> Self {
         Self {
-            frequency: 0,
-            json: false,
-            logprobs: false,
             model: "deepseek-chat".into(),
-            presence: 0,
-            stop: Vec::new(),
-            temperature: 1.0,
-            think: false,
-            tool_choice: ToolChoice::None,
-            tools: Vec::new(),
-            top_logprobs: 0,
-            top_p: 1.0,
-            tokens: 1000,
-            usage: true,
+            tools: None,
+            usage: false,
         }
     }
 }
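To make the new split concrete: `General` is the small provider-agnostic configuration users hand to `LLM::chat`, and each provider converts it into its own `Config` implementation. A toy sketch of the contract (illustrative only; the real `impl Config for Request` for DeepSeek appears at the end of this patch):

    use ucore::{Config, General, Tool, ToolChoice};

    #[derive(Clone)]
    struct ToyConfig {
        model: String,
        tools: Option<Vec<Tool>>,
        tool_choice: Option<ToolChoice>,
    }

    impl From<General> for ToyConfig {
        fn from(general: General) -> Self {
            Self {
                model: general.model,
                tools: general.tools,
                tool_choice: None,
            }
        }
    }

    impl Config for ToyConfig {
        // Chat-level: attach the agent's tool definitions.
        fn with_tools(mut self, tools: Vec<Tool>) -> Self {
            self.tools = Some(tools);
            self
        }

        // Message-level: pick a tool choice without consuming the config.
        fn with_tool_choice(&self, tool_choice: ToolChoice) -> Self {
            Self {
                tool_choice: Some(tool_choice),
                ..self.clone()
            }
        }
    }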
diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs
index 1e2aa12..17b1424 100644
--- a/crates/core/src/lib.rs
+++ b/crates/core/src/lib.rs
@@ -1,8 +1,9 @@
 //! Core abstractions for Unified LLM Interface
 
 pub use {
+    agent::Agent,
     chat::{Chat, ChatMessage},
-    config::Config,
+    config::{Config, General},
     message::{Message, Role},
     provider::LLM,
     reqwest::{self, Client},
@@ -11,15 +12,14 @@
         ResponseMessage, TopLogProb, Usage,
     },
     stream::{Delta, StreamChoice, StreamChunk},
-    template::Template,
     tool::{FunctionCall, Tool, ToolCall, ToolChoice},
 };
 
+mod agent;
 mod chat;
 mod config;
 mod message;
 mod provider;
 mod response;
 mod stream;
-mod template;
 mod tool;
diff --git a/crates/core/src/message.rs b/crates/core/src/message.rs
index c9434a2..85e9f13 100644
--- a/crates/core/src/message.rs
+++ b/crates/core/src/message.rs
@@ -42,6 +42,15 @@ impl Message {
             content: content.into(),
         }
     }
+
+    /// Create a new tool message
+    pub fn tool(content: impl Into<String>) -> Self {
+        Self {
+            role: Role::Tool,
+            name: String::new(),
+            content: content.into(),
+        }
+    }
 }
 
 /// A tool message in the chat
diff --git a/crates/core/src/provider.rs b/crates/core/src/provider.rs
index 6fbbe91..b2f5010 100644
--- a/crates/core/src/provider.rs
+++ b/crates/core/src/provider.rs
@@ -1,6 +1,6 @@
 //! Provider abstractions for the unified LLM Interfaces
 
-use crate::{Chat, ChatMessage, Config, Response, StreamChunk};
+use crate::{Chat, ChatMessage, Config, General, Response, StreamChunk};
 use anyhow::Result;
 use futures_core::Stream;
 use reqwest::Client;
@@ -8,7 +8,7 @@ use reqwest::Client;
 /// A trait for LLM providers
 pub trait LLM: Sized + Clone {
     /// The chat configuration.
-    type ChatConfig: From<Config>;
+    type ChatConfig: Config;
 
     /// Create a new LLM provider
     fn new(client: Client, key: &str) -> Result<Self>
     where
         Self: Sized;
 
     /// Create a new chat
-    fn chat(&self, config: Config) -> Chat<Self> {
-        Chat {
-            config: config.into(),
-            messages: Vec::new(),
-            provider: self.clone(),
-        }
-    }
+    fn chat(&self, config: General) -> Chat<Self> {
+        Chat::new(config, self.clone())
+    }
 
     /// Send a message to the LLM
     fn send(
         &mut self,
@@ -34,7 +30,8 @@
     /// Send a message to the LLM with streaming
     fn stream(
         &mut self,
-        config: &Self::ChatConfig,
+        config: Self::ChatConfig,
         messages: &[ChatMessage],
+        usage: bool,
     ) -> impl Stream<Item = Result<StreamChunk>>;
 }
diff --git a/crates/core/src/response.rs b/crates/core/src/response.rs
index dce5960..76852db 100644
--- a/crates/core/src/response.rs
+++ b/crates/core/src/response.rs
@@ -42,6 +42,20 @@ impl Response {
             .first()
             .and_then(|choice| choice.message.reasoning_content.as_ref())
     }
+
+    /// Get the tool calls from the response
+    pub fn tool_calls(&self) -> Option<&[ToolCall]> {
+        self.choices
+            .first()
+            .and_then(|choice| choice.message.tool_calls.as_deref())
+    }
+
+    /// Get the reason the model stopped generating
+    pub fn reason(&self) -> Option<&FinishReason> {
+        self.choices
+            .first()
+            .and_then(|choice| choice.finish_reason.as_ref())
+    }
 }
 
 /// A completion choice in a non-streaming response
diff --git a/crates/core/src/stream.rs b/crates/core/src/stream.rs
index 5a744e1..559b5b0 100644
--- a/crates/core/src/stream.rs
+++ b/crates/core/src/stream.rs
@@ -1,6 +1,6 @@
 //! Streaming response abstractions for the unified LLM Interfaces
 
-use crate::{Role, tool::ToolCall};
+use crate::{FinishReason, Role, tool::ToolCall};
 use serde::Deserialize;
 
 /// A streaming chat completion chunk
@@ -42,6 +42,20 @@ impl StreamChunk {
             .first()
             .and_then(|choice| choice.delta.reasoning_content.as_deref())
     }
+
+    /// Get the tool calls of the first choice
+    pub fn tool_calls(&self) -> Option<&[ToolCall]> {
+        self.choices
+            .first()
+            .and_then(|choice| choice.delta.tool_calls.as_deref())
+    }
+
+    /// Get the reason the model stopped generating
+    pub fn reason(&self) -> Option<&FinishReason> {
+        self.choices
+            .first()
+            .and_then(|choice| choice.finish_reason.as_ref())
+    }
 }
 
 /// A completion choice in a streaming response
diff --git a/crates/core/src/template.rs b/crates/core/src/template.rs
deleted file mode 100644
index 5510338..0000000
--- a/crates/core/src/template.rs
+++ /dev/null
@@ -1,37 +0,0 @@
-//! Turbofish Agent library
-
-use crate::{Message, Role};
-use serde::{Deserialize, Serialize};
-
-/// A template of the system prompt
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct Template {
-    /// The system prompt for the agent
-    pub system: String,
-
-    /// The input example
-    pub input: String,
-
-    /// The output json example
-    pub output: String,
-}
-
-impl Template {
-    /// Create a new message from the template
-    pub fn message(&self) -> Message {
-        Message {
-            content: format!(
-                r#"{}
-
-EXAMPLE INPUT:
-{}
-
-EXAMPLE JSON OUTPUT:
-{}"#,
-                self.system, self.input, self.output
-            ),
-            name: String::new(),
-            role: Role::System,
-        }
-    }
-}
diff --git a/crates/core/src/tool.rs b/crates/core/src/tool.rs
index c36a607..96a46fa 100644
--- a/crates/core/src/tool.rs
+++ b/crates/core/src/tool.rs
@@ -44,7 +44,7 @@ pub struct FunctionCall {
 }
 
 /// Controls which tool is called by the model
-#[derive(Debug, Clone, Deserialize, Serialize)]
+#[derive(Debug, Clone, Deserialize, Serialize, Default)]
 pub enum ToolChoice {
     /// Model will not call any tool
     #[serde(rename = "none")]
     None,
 
     /// Model can pick between generating a message or calling tools
     #[serde(rename = "auto")]
+    #[default]
     Auto,
 
     /// Model must call one or more tools
@@ -72,12 +73,11 @@ pub struct ToolChoiceFunction {
     pub name: String,
 }
 
-impl ToolChoice {
-    /// Create a tool choice for a specific function
-    pub fn function(name: impl Into<String>) -> Self {
+impl From<&str> for ToolChoice {
+    fn from(value: &str) -> Self {
         ToolChoice::Function {
             r#type: "function".into(),
-            function: ToolChoiceFunction { name: name.into() },
+            function: ToolChoiceFunction { name: value.into() },
         }
     }
 }
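The `ToolChoice` changes in miniature (the function name is illustrative):

    use ucore::ToolChoice;

    fn tool_choice_examples() -> (ToolChoice, ToolChoice) {
        // `Default` now resolves to `Auto`, matching `Agent::filter`'s fallback.
        let auto = ToolChoice::default();

        // `From<&str>` replaces `ToolChoice::function`: a function name
        // converts into a forced tool choice.
        let forced = ToolChoice::from("now");

        (auto, forced)
    }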
diff --git a/crates/ullm/Cargo.toml b/crates/ullm/Cargo.toml
index 424190d..5dfdaa5 100644
--- a/crates/ullm/Cargo.toml
+++ b/crates/ullm/Cargo.toml
@@ -11,14 +11,3 @@ keywords.workspace = true
 
 [dependencies]
 deepseek.workspace = true
 ucore.workspace = true
-
-# crates-io dependencies
-anyhow.workspace = true
-clap.workspace = true
-dirs.workspace = true
-futures-util.workspace = true
-serde.workspace = true
-tokio.workspace = true
-toml.workspace = true
-tracing.workspace = true
-tracing-subscriber.workspace = true
diff --git a/crates/ullm/src/lib.rs b/crates/ullm/src/lib.rs
index ea218b1..a02ea39 100644
--- a/crates/ullm/src/lib.rs
+++ b/crates/ullm/src/lib.rs
@@ -1,6 +1,6 @@
 //! Unified LLM Interface
-
-pub mod cmd;
+//!
+//! This is the umbrella crate that re-exports all ullm components.
 
 pub use deepseek::DeepSeek;
-pub use ucore::{Chat, ChatMessage, Client, Config, LLM, Message, Response, StreamChunk};
+pub use ucore::{self, Chat, ChatMessage, Client, Config, General, LLM, Message};
diff --git a/llm/deepseek/src/llm.rs b/llm/deepseek/src/llm.rs
index c2bc608..93c8a75 100644
--- a/llm/deepseek/src/llm.rs
+++ b/llm/deepseek/src/llm.rs
@@ -6,7 +6,7 @@ use async_stream::try_stream;
 use futures_core::Stream;
 use futures_util::StreamExt;
 use ucore::{
-    Chat, ChatMessage, Client, Config, LLM, Response, StreamChunk,
+    ChatMessage, Client, LLM, Response, StreamChunk,
     reqwest::{
         Method,
         header::{self, HeaderMap},
@@ -17,7 +17,7 @@ const ENDPOINT: &str = "https://api.deepseek.com/chat/completions";
 
 impl LLM for DeepSeek {
     /// The chat configuration.
-    type ChatConfig = Config;
+    type ChatConfig = Request;
 
     /// Create a new LLM provider
     fn new(client: Client, key: &str) -> Result<Self> {
@@ -28,22 +28,13 @@ impl LLM for DeepSeek {
         Ok(Self { client, headers })
     }
 
-    /// Create a new chat
-    fn chat(&self, config: Config) -> Chat<Self> {
-        Chat {
-            config,
-            messages: Vec::new(),
-            provider: self.clone(),
-        }
-    }
-
     /// Send a message to the LLM
-    async fn send(&mut self, config: &Config, messages: &[ChatMessage]) -> Result<Response> {
+    async fn send(&mut self, req: &Request, messages: &[ChatMessage]) -> Result<Response> {
         let text = self
             .client
             .request(Method::POST, ENDPOINT)
             .headers(self.headers.clone())
-            .json(&Request::from(config).messages(messages))
+            .json(&req.messages(messages))
             .send()
             .await?
             .text()
@@ -64,18 +55,15 @@ impl LLM for DeepSeek {
     /// Send a message to the LLM with streaming
     fn stream(
         &mut self,
-        config: &Config,
+        req: Request,
         messages: &[ChatMessage],
+        usage: bool,
     ) -> impl Stream<Item = Result<StreamChunk>> {
         let request = self
             .client
             .request(Method::POST, ENDPOINT)
             .headers(self.headers.clone())
-            .json(
-                &Request::from(config)
-                    .messages(messages)
-                    .stream(config.usage),
-            );
+            .json(&req.messages(messages).stream(usage));
 
         try_stream! {
             let mut stream = request.send().await?.bytes_stream();
diff --git a/llm/deepseek/src/request.rs b/llm/deepseek/src/request.rs
index f428ec3..2fc10d4 100644
--- a/llm/deepseek/src/request.rs
+++ b/llm/deepseek/src/request.rs
@@ -1,77 +1,79 @@
 //! The request body for the DeepSeek API
 
 use serde::Serialize;
-use serde_json::{Number, Value, json};
-use ucore::{ChatMessage, Config, Tool};
+use serde_json::{Value, json};
+use ucore::{ChatMessage, Config, General, Tool, ToolChoice};
 
 /// The request body for the DeepSeek API
 #[derive(Debug, Clone, Serialize)]
 pub struct Request {
-    /// The frequency penalty to use for the response
-    #[serde(skip_serializing_if = "Value::is_null")]
-    pub frequency_penalty: Value,
-
-    /// Whether to return the log probabilities
-    #[serde(skip_serializing_if = "Value::is_null")]
-    pub logprobs: Value,
-
-    /// The maximum number of tokens to generate
-    pub max_tokens: usize,
-
     /// The messages to send to the API
     pub messages: Vec<ChatMessage>,
 
     /// The model we are using
     pub model: String,
 
+    /// The frequency penalty to use for the response
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub frequency_penalty: Option<f32>,
+
+    /// Whether to return the log probabilities
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub logprobs: Option<bool>,
+
+    /// The maximum number of tokens to generate
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub max_tokens: Option<usize>,
+
     /// The presence penalty to use for the response
-    #[serde(skip_serializing_if = "Value::is_null")]
-    pub presence_penalty: Value,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub presence_penalty: Option<f32>,
 
     /// The response format to use
-    #[serde(skip_serializing_if = "Value::is_null")]
-    pub response_format: Value,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub response_format: Option<Value>,
 
     /// Stop sequences
-    #[serde(skip_serializing_if = "Value::is_null")]
-    pub stop: Value,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stop: Option<Value>,
 
     /// Whether to stream the response
-    pub stream: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stream: Option<bool>,
 
     /// Stream options
-    #[serde(skip_serializing_if = "Value::is_null")]
-    pub stream_options: Value,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stream_options: Option<Value>,
 
     /// Whether to enable thinking
-    #[serde(skip_serializing_if = "Value::is_null")]
-    pub thinking: Value,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub thinking: Option<Value>,
 
     /// The temperature to use for the response
-    #[serde(skip_serializing_if = "Value::is_null")]
-    pub temperature: Value,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub temperature: Option<f32>,
 
     /// Controls which (if any) tool is called by the model
-    #[serde(skip_serializing_if = "Value::is_null")]
-    pub tool_choice: Value,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<Value>,
 
     /// A list of tools the model may call
-    #[serde(skip_serializing_if = "Value::is_null")]
-    pub tools: Value,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tools: Option<Value>,
 
     /// An integer between 0 and 20 specifying the number of most likely tokens to
     /// return at each token position, each with an associated log probability.
-    #[serde(skip_serializing_if = "Value::is_null")]
-    pub top_logprobs: Value,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_logprobs: Option<usize>,
 
     /// The top probability to use for the response
-    #[serde(skip_serializing_if = "Value::is_null")]
-    pub top_p: Value,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_p: Option<f32>,
 }
 
 impl Request {
     /// Construct the messages for the request
-    pub fn messages(&mut self, messages: &[ChatMessage]) -> Self {
+    pub fn messages(&self, messages: &[ChatMessage]) -> Self {
         Self {
             messages: messages.to_vec(),
             ..self.clone()
         }
     }
 
     /// Enable streaming for the request
     pub fn stream(mut self, usage: bool) -> Self {
-        self.stream = true;
+        self.stream = Some(true);
         self.stream_options = if usage {
-            json!({ "include_usage": true })
+            Some(json!({ "include_usage": true }))
         } else {
-            Value::Null
+            None
         };
 
         self
     }
 }
 
-impl From<&Config> for Request {
-    fn from(config: &Config) -> Self {
+impl From<General> for Request {
+    fn from(config: General) -> Self {
         Self {
-            frequency_penalty: Number::from_f64(config.frequency as f64)
-                .map(Value::Number)
-                .unwrap_or(Value::Null),
-            logprobs: if config.logprobs {
-                Value::Bool(true)
-            } else {
-                Value::Null
-            },
-            max_tokens: config.tokens,
             messages: Vec::new(),
             model: config.model.clone(),
-            presence_penalty: Number::from_f64(config.presence as f64)
-                .map(Value::Number)
-                .unwrap_or(Value::Null),
-            response_format: if config.json {
-                json!({ "type": "json_object" })
-            } else {
-                Value::Null
-            },
-            stop: if config.stop.is_empty() {
-                Value::Null
-            } else {
-                config.stop.iter().map(|s| json!(s)).collect()
-            },
-            stream: false,
-            stream_options: Value::Null,
-            temperature: Number::from_f64(config.temperature as f64)
-                .map(Value::Number)
-                .unwrap_or(Value::Null),
-            thinking: if config.think {
-                json!({ "type": "enabled" })
-            } else {
-                Value::Null
-            },
-            tool_choice: serde_json::to_value(&config.tool_choice).unwrap_or(Value::Null),
-            tools: serialize_tools(&config.tools),
-            top_logprobs: if config.logprobs {
-                Value::Number(config.top_logprobs.into())
-            } else {
-                Value::Null
-            },
-            top_p: Number::from_f64(config.top_p as f64)
-                .map(Value::Number)
-                .unwrap_or(Value::Null),
+            frequency_penalty: None,
+            logprobs: None,
+            max_tokens: None,
+            presence_penalty: None,
+            response_format: None,
+            stop: None,
+            stream: None,
+            stream_options: None,
+            thinking: None,
+            temperature: None,
+            tool_choice: None,
+            tools: None,
+            top_logprobs: None,
+            top_p: None,
         }
     }
 }
 
-/// Serialize tools to JSON value
-fn serialize_tools(tools: &[Tool]) -> Value {
-    if tools.is_empty() {
-        return Value::Null;
-    }
-
-    let tools: Vec<Value> = tools
-        .iter()
-        .map(|tool| json!({ "type": "function", "function": tool }))
-        .collect();
-
-    Value::Array(tools)
-}
+impl Config for Request {
+    fn with_tools(self, tools: Vec<Tool>) -> Self {
+        Self {
+            tools: Some(json!(tools)),
+            ..self
+        }
+    }
+
+    fn with_tool_choice(&self, tool_choice: ToolChoice) -> Self {
+        Self {
+            tool_choice: Some(json!(tool_choice)),
+            ..self.clone()
+        }
+    }
+}
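For reference, a minimal end-to-end sketch of the streaming path after this patch (hedged: `Message::user` is assumed to exist alongside the constructors shown above; the API key is a placeholder):

    use anyhow::Result;
    use futures_util::{StreamExt, pin_mut};
    use ullm::{Client, DeepSeek, General, LLM, Message};

    async fn stream_demo() -> Result<()> {
        let llm = DeepSeek::new(Client::new(), "YOUR_API_KEY")?;
        let mut chat = llm.chat(General::new("deepseek-chat"));

        // With the default `()` agent, chunks come back as raw `StreamChunk`s;
        // tool calls (if any) round-trip inside `Chat::stream` itself.
        let stream = chat.stream(Message::user("Tell me a joke"));
        pin_mut!(stream);
        while let Some(chunk) = stream.next().await {
            let chunk = chunk?;
            if let Some(content) = chunk.content() {
                print!("{content}");
            }
        }
        Ok(())
    }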