Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat: tool calling shim for models without native tool calling / structured output capabilities #1147

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions crates/goose-cli/src/commands/configure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ pub async fn configure_provider_dialog() -> Result<bool, Box<dyn Error>> {
}
Err(e) => {
spin.stop(style(e.to_string()).red());
println!("{:?}", e);
cliclack::outro(style("Failed to configure provider: init chat completion request with tool did not succeed.").on_red().white())?;
Ok(false)
}
Expand Down
14 changes: 14 additions & 0 deletions crates/goose-mcp/src/developer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,18 @@ impl DeveloperRouter {
}),
);

let ask_user_next_step_tool = Tool::new(
"ask_user_next_step",
indoc! {r#"
Request user input for what to do next without performing any other actions.
"#},
json!({
"type": "object",
"required": [],
"properties": {}
}),
);

// Get base instructions and working directory
let cwd = std::env::current_dir().expect("should have a current working dir");
let base_instructions = formatdoc! {r#"
Expand Down Expand Up @@ -193,6 +205,7 @@ impl DeveloperRouter {
text_editor_tool,
list_windows_tool,
screen_capture_tool,
ask_user_next_step_tool,
],
file_history: Arc::new(Mutex::new(HashMap::new())),
instructions,
Expand Down Expand Up @@ -665,6 +678,7 @@ impl Router for DeveloperRouter {
"text_editor" => this.text_editor(arguments).await,
"list_windows" => this.list_windows(arguments).await,
"screen_capture" => this.screen_capture(arguments).await,
"ask_user_next_step" => Ok(vec![Content::text("Waiting for user input").with_audience(vec![Role::Assistant])]),
_ => Err(ToolError::NotFound(format!("Tool {} not found", tool_name))),
}
})
Expand Down
10 changes: 5 additions & 5 deletions crates/goose/src/prompts/system.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
You are a general purpose AI agent called Goose. You are capable
You are a general purpose AI agent called Goose. You collaborate with users who are capable
of dynamically plugging into new extensions and learning how to use them.

You solve higher level problems using the tools in these extensions, and can
interact with multiple at once.
You collaborate with a user to solve higher-level problems using the tools in these extensions, and can
interact with multiple extensions at once. You must tell the user which extension and tools to use, then wait for the user to run them and report the result back to you.

{% if (extensions is defined) and extensions %}
Because you dynamically load extensions, your conversation history may refer
Because the user may dynamically load extensions, your conversation history may refer
to interactions with extensions that are not currently active. The currently
active extensions are below. Each of these extensions provides tools that are
in your tool specification.
in your user's tool specification.

# Extensions:
{% for extension in extensions %}
Expand Down
1 change: 1 addition & 0 deletions crates/goose/src/providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub mod oauth;
pub mod ollama;
pub mod openai;
pub mod openrouter;
pub mod tool_parser;
pub mod utils;

pub use factory::{create, providers};
97 changes: 86 additions & 11 deletions crates/goose/src/providers/ollama.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
use super::errors::ProviderError;
use super::utils::{get_model, handle_response_openai_compat};
use crate::message::Message;
use super::tool_parser::ToolParserProvider;
use crate::message::{Message, MessageContent};
use crate::model::ModelConfig;
use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
use anyhow::Result;
use async_trait::async_trait;
use mcp_core::tool::Tool;
use mcp_core::{role::Role, tool::Tool, content::TextContent};
use reqwest::Client;
use serde_json::Value;
use std::time::Duration;
use url::Url;
use regex::Regex;
use chrono::Utc;

pub const OLLAMA_HOST: &str = "localhost";
pub const OLLAMA_DEFAULT_PORT: u16 = 11434;
pub const OLLAMA_DEFAULT_MODEL: &str = "qwen2.5";
// Ollama can run many models, we only provide the default
// Ollama can run many models, we suggest the default
pub const OLLAMA_KNOWN_MODELS: &[&str] = &[OLLAMA_DEFAULT_MODEL];
pub const OLLAMA_DOC_URL: &str = "https://ollama.com/library";

Expand All @@ -25,11 +28,13 @@ pub struct OllamaProvider {
client: Client,
host: String,
model: ModelConfig,
#[serde(skip)]
tool_parser: ToolParserProvider,
}

impl Default for OllamaProvider {
fn default() -> Self {
let model = ModelConfig::new(OllamaProvider::metadata().default_model);
let model = ModelConfig::new(OllamaProvider::metadata().default_model.to_string());
OllamaProvider::from_env(model).expect("Failed to initialize Ollama provider")
}
}
Expand All @@ -49,6 +54,7 @@ impl OllamaProvider {
client,
host,
model,
tool_parser: ToolParserProvider::default(),
})
}

Expand Down Expand Up @@ -82,6 +88,80 @@ impl OllamaProvider {
}
}

/// Builds an OpenAI-compatible chat request, embedding the tool list directly
/// into the system prompt so models without native tool-calling support can
/// still emit tool invocations as JSON text.
///
/// When `tools` is non-empty, the pretty-printed tool specs plus an instruction
/// describing the expected `{ "tool": ..., "args": ... }` reply format are
/// appended to `system`; otherwise the prompt passes through unchanged.
///
/// # Errors
/// Propagates any error from the underlying `create_request` call.
fn create_request_with_tools(
    model_config: &ModelConfig,
    system: &str,
    messages: &[Message],
    tools: &[Tool],
) -> anyhow::Result<Value> {
    // `anyhow::Result<T>` already fixes the error type to `anyhow::Error`;
    // spelling it twice (`anyhow::Result<Value, anyhow::Error>`) was redundant.
    let mut modified_system = system.to_string();
    if !tools.is_empty() {
        // For providers without native tool calling, embed the list of tools directly into the system prompt.
        modified_system.push_str("\nAvailable tools: ");
        let tools_text = serde_json::to_string_pretty(&tools)
            .unwrap_or_else(|_| "[Error serializing tools]".to_string());
        modified_system.push_str(&tools_text);
        modified_system.push_str("\nWhen you want to use a tool, respond with a JSON object in this format: { \"tool\": \"tool_name\", \"args\": { \"arg1\": \"value1\", ... } }");
    }

    create_request(
        model_config,
        &modified_system,
        messages,
        tools,
        &super::utils::ImageFormat::OpenAi,
    )
}

/// Scans `text` for top-level balanced `{ ... }` spans and returns each as a
/// candidate JSON slice.
///
/// Unlike the regex `\{[^{}]*\}`, this handles nested objects such as
/// `{"tool": "x", "args": {"a": 1}}` — which the documented tool-call format
/// always contains, since `"args"` is itself an object.
///
/// NOTE(review): braces inside JSON string literals are not tracked, so a
/// literal containing `{` or `}` can skew span boundaries. Acceptable for
/// best-effort extraction: malformed spans simply fail `serde_json` parsing
/// at the call site.
fn extract_json_candidates(text: &str) -> Vec<&str> {
    let mut candidates = Vec::new();
    let mut depth = 0usize;
    let mut start = 0usize;
    for (i, ch) in text.char_indices() {
        match ch {
            '{' => {
                if depth == 0 {
                    start = i;
                }
                depth += 1;
            }
            // `}` is a single byte, so `i + 1` is always a valid char boundary.
            '}' if depth > 0 => {
                depth -= 1;
                if depth == 0 {
                    candidates.push(&text[start..i + 1]);
                }
            }
            _ => {}
        }
    }
    candidates
}

/// Post-processes a model response for providers without native tool calling.
///
/// Extracts `{ "tool": ..., "args": ... }` JSON objects the model emitted as
/// plain text and re-emits each as assistant text content. If none are found,
/// falls back to `tool_parser`; if that also fails, the original text is
/// passed through unchanged. An empty input text yields an empty-content
/// assistant message (matching the original behavior).
async fn process_tool_calls(message: Message, tool_parser: &ToolParserProvider, tools: &[Tool]) -> Message {
    let mut processed = Message {
        role: Role::Assistant,
        created: Utc::now().timestamp(),
        content: vec![],
    };

    // Extract tool calls from the message content
    let text = message.as_concat_text();
    if !text.is_empty() {
        let mut found_valid_json = false;

        // Fix: the previous regex (`\{[^{}]*\}`) could never match the
        // prescribed format, because every well-formed tool call nests an
        // object under "args" and therefore contains inner braces. A
        // balanced-brace scan handles nesting correctly.
        for candidate in extract_json_candidates(&text) {
            if let Ok(json) = serde_json::from_str::<Value>(candidate) {
                if let (Some(tool), Some(args)) = (json.get("tool"), json.get("args")) {
                    if tool.as_str().is_some() && args.as_object().is_some() {
                        found_valid_json = true;
                        processed.content.push(MessageContent::Text(TextContent {
                            // Serializing a `Value` we just parsed cannot fail.
                            text: serde_json::to_string(&json).unwrap(),
                            annotations: None,
                        }));
                    }
                }
            }
        }

        // If no valid JSON was found, try using the tool parser
        if !found_valid_json {
            if let Ok(tool_calls) = tool_parser.parse_tool_calls(&text, tools).await {
                for tool_call in tool_calls {
                    processed.content.push(MessageContent::Text(TextContent {
                        text: serde_json::to_string(&tool_call).unwrap(),
                        annotations: None,
                    }));
                }
            } else {
                // If tool parser fails, pass through the original text
                processed.content.push(MessageContent::Text(TextContent {
                    text,
                    annotations: None,
                }));
            }
        }
    }

    processed
}

#[async_trait]
impl Provider for OllamaProvider {
fn metadata() -> ProviderMetadata {
Expand Down Expand Up @@ -115,17 +195,12 @@ impl Provider for OllamaProvider {
messages: &[Message],
tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
let payload = create_request(
&self.model,
system,
messages,
tools,
&super::utils::ImageFormat::OpenAi,
)?;
let payload = create_request_with_tools(&self.model, system, messages, tools)?;
let response = self.post(payload.clone()).await?;

// Parse response
let message = response_to_message(response.clone())?;
let message = process_tool_calls(message, &self.tool_parser, tools).await;
let usage = match get_usage(&response) {
Ok(usage) => usage,
Err(ProviderError::UsageError(e)) => {
Expand Down
Loading