Skip to content

Commit 47ff5ee

Browse files
committed
feat: tool management
1 parent aeb6d1f commit 47ff5ee

File tree

9 files changed

+440
-291
lines changed

9 files changed

+440
-291
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ target
22
Cargo.lock
33
**/*.rs.bk
44
.DS_Store
5+
.env
56

67
# directory used to store images
78
data

async-openai/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ secrecy = { version = "0.10.3", features = ["serde"] }
5050
bytes = "1.9.0"
5151
eventsource-stream = "0.2.3"
5252
tokio-tungstenite = { version = "0.26.1", optional = true, default-features = false }
53+
schemars = "0.8.22"
5354

5455
[dev-dependencies]
5556
tokio-test = "0.4.4"

async-openai/src/client.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,6 @@ where
485485
if message.data == "[DONE]" {
486486
break;
487487
}
488-
489488
let response = match serde_json::from_str::<O>(&message.data) {
490489
Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
491490
Ok(output) => Ok(output),

async-openai/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ mod projects;
149149
mod runs;
150150
mod steps;
151151
mod threads;
152+
pub mod tools;
152153
pub mod traits;
153154
pub mod types;
154155
mod uploads;

async-openai/src/tools.rs

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
//! This module provides functionality for managing and executing tools in an async OpenAI context.
2+
//! It defines traits and structures for tool management, execution, and streaming.
3+
use std::{
4+
collections::{BTreeMap, HashMap},
5+
future::Future,
6+
pin::Pin,
7+
sync::Arc,
8+
};
9+
10+
use schemars::{schema_for, JsonSchema};
11+
use serde::{Deserialize, Serialize};
12+
use serde_json::json;
13+
14+
use crate::types::{
15+
ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
16+
ChatCompletionRequestToolMessage, ChatCompletionTool, ChatCompletionToolType, FunctionCall,
17+
FunctionObject,
18+
};
19+
20+
/// A trait defining the interface for tools that can be used with the OpenAI API.
21+
/// Tools must implement this trait to be used with the ToolManager.
22+
pub trait Tool: Send + Sync {
23+
/// The type of arguments that the tool accepts.
24+
type Args: JsonSchema + for<'a> Deserialize<'a> + Send + Sync;
25+
/// The type of output that the tool produces.
26+
type Output: Serialize + Send + Sync;
27+
/// The type of error that the tool can return.
28+
type Error: ToString + Send + Sync;
29+
30+
/// Returns the name of the tool.
31+
fn name() -> String {
32+
Self::Args::schema_name()
33+
}
34+
35+
/// Returns an optional description of the tool.
36+
fn description() -> Option<String> {
37+
None
38+
}
39+
40+
/// Returns an optional boolean indicating whether the tool should be strict about the arguments.
41+
fn strict() -> Option<bool> {
42+
None
43+
}
44+
45+
/// Creates a ChatCompletionTool definition for the tool.
46+
fn definition() -> ChatCompletionTool {
47+
ChatCompletionTool {
48+
r#type: ChatCompletionToolType::Function,
49+
function: FunctionObject {
50+
name: Self::name(),
51+
description: Self::description(),
52+
parameters: Some(json!(schema_for!(Self::Args))),
53+
strict: Self::strict(),
54+
},
55+
}
56+
}
57+
58+
/// Executes the tool with the given arguments.
59+
/// Returns a Future that resolves to either the tool's output or an error.
60+
fn call(
61+
&self,
62+
args: Self::Args,
63+
) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send;
64+
}
65+
66+
/// A dynamic trait for tools that allows for runtime tool management.
67+
/// This trait provides a way to work with tools without knowing their concrete types at compile time.
68+
pub trait ToolDyn: Send + Sync {
69+
/// Returns the tool's definition as a ChatCompletionTool.
70+
fn definition(&self) -> ChatCompletionTool;
71+
72+
/// Executes the tool with the given JSON string arguments.
73+
/// Returns a Future that resolves to either a JSON string output or an error string.
74+
fn call(
75+
&self,
76+
args: String,
77+
) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>>;
78+
}
79+
80+
// Implementation of ToolDyn for any type that implements Tool
81+
impl<T: Tool> ToolDyn for T {
82+
fn definition(&self) -> ChatCompletionTool {
83+
T::definition()
84+
}
85+
86+
fn call(
87+
&self,
88+
args: String,
89+
) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
90+
let future = async move {
91+
// Special handling for T::Args = () case
92+
// If the tool doesn't require arguments (T::Args is unit type),
93+
// we can safely ignore the provided arguments string
94+
match serde_json::from_str::<T::Args>(&args)
95+
.or_else(|e| serde_json::from_str::<T::Args>("null").map_err(|_| e))
96+
{
97+
Ok(args) => T::call(self, args)
98+
.await
99+
.map_err(|e| e.to_string())
100+
.and_then(|output| {
101+
serde_json::to_string(&output)
102+
.map_err(|e| format!("Failed to serialize output: {}", e))
103+
}),
104+
Err(e) => Err(format!("Failed to parse arguments: {}", e)),
105+
}
106+
};
107+
Box::pin(future)
108+
}
109+
}
110+
111+
/// A manager for tools that allows adding, retrieving, and executing tools.
112+
#[derive(Default, Clone)]
113+
pub struct ToolManager {
114+
/// A map of tool names to their dynamic implementations.
115+
tools: BTreeMap<String, Arc<dyn ToolDyn>>,
116+
}
117+
118+
impl ToolManager {
119+
/// Creates a new ToolManager.
120+
pub fn new() -> Self {
121+
Self {
122+
tools: BTreeMap::new(),
123+
}
124+
}
125+
126+
/// Adds a new tool to the manager.
127+
pub fn add_tool<T: Tool + 'static>(&mut self, tool: T) {
128+
self.tools
129+
.insert(T::name(), Arc::new(tool));
130+
}
131+
132+
/// Adds a new tool with an Arc to the manager.
133+
///
134+
/// Use this if you want to access this tool after being added to the manager.
135+
pub fn add_tool_dyn(&mut self, tool: Arc<dyn ToolDyn>) {
136+
self.tools.insert(tool.definition().function.name, tool);
137+
}
138+
139+
/// Removes a tool from the manager.
140+
pub fn remove_tool(&mut self, name: &str) -> bool {
141+
self.tools.remove(name).is_some()
142+
}
143+
144+
/// Returns the definitions of all tools in the manager.
145+
pub fn get_tools(&self) -> Vec<ChatCompletionTool> {
146+
self.tools.values().map(|tool| tool.definition()).collect()
147+
}
148+
149+
/// Executes multiple tool calls concurrently and returns their results.
150+
pub async fn call(
151+
&self,
152+
calls: impl IntoIterator<Item = ChatCompletionMessageToolCall>,
153+
) -> Vec<ChatCompletionRequestToolMessage> {
154+
let mut handles = Vec::new();
155+
let mut outputs = Vec::new();
156+
157+
// Spawn a task for each tool call
158+
for call in calls {
159+
if let Some(tool) = self.tools.get(&call.function.name).cloned() {
160+
let handle = tokio::spawn(async move { tool.call(call.function.arguments).await });
161+
handles.push((call.id, handle));
162+
} else {
163+
outputs.push(ChatCompletionRequestToolMessage {
164+
content: "Tool call failed: tool not found".into(),
165+
tool_call_id: call.id,
166+
});
167+
}
168+
}
169+
170+
// Collect results from all spawned tasks
171+
for (id, handle) in handles {
172+
let output = match handle.await {
173+
Ok(Ok(output)) => output,
174+
Ok(Err(e)) => {
175+
format!("Tool call failed: {}", e)
176+
}
177+
Err(_) => "Tool call failed: runtime error".to_string(),
178+
};
179+
outputs.push(ChatCompletionRequestToolMessage {
180+
content: output.into(),
181+
tool_call_id: id,
182+
});
183+
}
184+
outputs
185+
}
186+
}
187+
188+
/// A manager for handling streaming tool calls.
189+
/// This structure helps manage and merge tool call chunks that arrive in a streaming fashion.
190+
#[derive(Default, Clone, Debug)]
191+
pub struct ToolCallStreamManager(HashMap<u32, ChatCompletionMessageToolCall>);
192+
193+
impl ToolCallStreamManager {
194+
/// Creates a new empty ToolCallStreamManager.
195+
pub fn new() -> Self {
196+
Self(HashMap::new())
197+
}
198+
199+
/// Processes a single streaming tool call chunk and merges it with existing data.
200+
pub fn process_chunk(&mut self, chunk: ChatCompletionMessageToolCallChunk) {
201+
let tool_call =
202+
self.0
203+
.entry(chunk.index)
204+
.or_insert_with(|| ChatCompletionMessageToolCall {
205+
id: "".to_string(),
206+
r#type: ChatCompletionToolType::Function,
207+
function: FunctionCall {
208+
name: "".to_string(),
209+
arguments: "".to_string(),
210+
},
211+
});
212+
if let Some(id) = chunk.id {
213+
tool_call.id = id;
214+
}
215+
if let Some(function) = chunk.function {
216+
if let Some(name) = function.name {
217+
tool_call.function.name = name;
218+
}
219+
if let Some(arguments) = function.arguments {
220+
tool_call.function.arguments.push_str(&arguments);
221+
}
222+
}
223+
}
224+
225+
/// Processes multiple streaming tool call chunks and merges them with existing data.
226+
pub fn process_chunks(
227+
&mut self,
228+
chunks: impl IntoIterator<Item = ChatCompletionMessageToolCallChunk>,
229+
) {
230+
for chunk in chunks {
231+
self.process_chunk(chunk);
232+
}
233+
}
234+
235+
/// Returns all completed tool calls as a vector.
236+
pub fn finish_stream(self) -> Vec<ChatCompletionMessageToolCall> {
237+
self.0
238+
.into_values()
239+
.filter(|tool_call| {
240+
let is_complete = !tool_call.id.is_empty() && !tool_call.function.name.is_empty();
241+
if !is_complete {
242+
tracing::error!("Tool call is not complete: {:?}", tool_call);
243+
}
244+
is_complete
245+
})
246+
.collect()
247+
}
248+
}

examples/tool-call-stream/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ publish = false
77
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
88

99
[dependencies]
10-
async-openai = {path = "../../async-openai"}
10+
async-openai = { path = "../../async-openai" }
1111
rand = "0.8.5"
12+
serde = "1.0"
1213
serde_json = "1.0.135"
1314
tokio = { version = "1.43.0", features = ["full"] }
1415
futures = "0.3.31"
16+
schemars = "0.8.22"

0 commit comments

Comments
 (0)