Skip to content

Commit 75a3e76

Browse files
chaynaborsbrandonskiserdingfeli
authored andcommitted
fix chat help (#1830)
Co-authored-by: Brandon Kiser <[email protected]> Co-authored-by: Felix Ding <[email protected]>
1 parent f414ddc commit 75a3e76

File tree

10 files changed

+135
-167
lines changed

10 files changed

+135
-167
lines changed

crates/chat-cli/src/cli/chat/cli.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ pub struct Chat {
1717
/// prompt requests permissions to use a tool, unless --trust-all-tools is also used.
1818
#[arg(long)]
1919
pub no_interactive: bool,
20-
/// Start a new conversation and overwrites any previous conversation from this directory.
21-
#[arg(long)]
22-
pub new: bool,
20+
/// Resumes the previous conversation from this directory.
21+
#[arg(short, long)]
22+
pub resume: bool,
2323
/// The first question to ask
2424
pub input: Option<String>,
2525
/// Context profile to use

crates/chat-cli/src/cli/chat/command.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ pub enum Command {
5151
subcommand: Option<PromptsSubcommand>,
5252
},
5353
Usage,
54-
Import {
54+
Load {
5555
path: String,
5656
},
57-
Export {
57+
Save {
5858
path: String,
5959
force: bool,
6060
},
@@ -818,15 +818,15 @@ impl Command {
818818
}
819819
},
820820
"usage" => Self::Usage,
821-
"import" => {
821+
"load" => {
822822
let Some(path) = parts.get(1) else {
823823
return Err("path is required".to_string());
824824
};
825-
Self::Import {
825+
Self::Load {
826826
path: (*path).to_string(),
827827
}
828828
},
829-
"export" => {
829+
"save" => {
830830
let force = parts.contains(&"-f") || parts.contains(&"--force");
831831
let Some(path) = parts.get(1) else {
832832
return Err("path is required".to_string());
@@ -835,7 +835,7 @@ impl Command {
835835
if !path.ends_with(".json") {
836836
path.push_str(".json");
837837
}
838-
Self::Export { path, force }
838+
Self::Save { path, force }
839839
},
840840
unknown_command => {
841841
let looks_like_path = {

crates/chat-cli/src/cli/chat/conversation_state.rs

Lines changed: 48 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::collections::{
22
HashMap,
3+
HashSet,
34
VecDeque,
45
};
56
use std::sync::Arc;
@@ -32,7 +33,6 @@ use super::hooks::{
3233
};
3334
use super::message::{
3435
AssistantMessage,
35-
AssistantToolUse,
3636
ToolUseResult,
3737
ToolUseResultBlock,
3838
UserMessage,
@@ -60,7 +60,6 @@ use crate::api_client::model::{
6060
ToolInputSchema,
6161
ToolResult,
6262
ToolResultContentBlock,
63-
ToolResultStatus,
6463
ToolSpecification,
6564
ToolUse,
6665
UserInputMessage,
@@ -347,7 +346,7 @@ impl ConversationState {
347346
}
348347
}
349348

350-
self.enforce_tool_use_history_invariants(true);
349+
self.enforce_tool_use_history_invariants();
351350
}
352351

353352
/// Here we also need to make sure that the tool result corresponds to one of the tools
@@ -362,105 +361,52 @@ impl ConversationState {
362361
/// intervention here is to substitute the ambiguous, partial name with a dummy.
363362
/// 3. The model had decided to call a tool that does not exist. The intervention here is to
364363
/// substitute the non-existent tool name with a dummy.
365-
pub fn enforce_tool_use_history_invariants(&mut self, last_only: bool) {
366-
let tool_name_list = self.tool_manager.tn_map.keys().map(String::as_str).collect::<Vec<_>>();
367-
// We need to first determine what the range of interest is. There are two places where we
368-
// would call this function:
369-
// 1. When there are changes to the list of available tools, in which case we comb through the
370-
// entire conversation
371-
// 2. When we send a message, in which case we only examine the most recent entry
372-
let (tool_use_results, mut tool_uses) = if last_only {
373-
if let (Some((_, AssistantMessage::ToolUse { ref mut tool_uses, .. })), Some(user_msg)) = (
374-
self.history
375-
.range_mut(self.valid_history_range.0..self.valid_history_range.1)
376-
.last(),
377-
&mut self.next_message,
378-
) {
379-
let tool_use_results = user_msg
380-
.tool_use_results()
381-
.map_or(Vec::new(), |results| results.iter().collect::<Vec<_>>());
382-
let tool_uses = tool_uses.iter_mut().collect::<Vec<_>>();
383-
(tool_use_results, tool_uses)
384-
} else {
385-
(Vec::new(), Vec::new())
386-
}
387-
} else {
388-
let tool_use_results = self.next_message.as_ref().map_or(Vec::new(), |user_msg| {
389-
user_msg
390-
.tool_use_results()
391-
.map_or(Vec::new(), |results| results.iter().collect::<Vec<_>>())
392-
});
393-
self.history
394-
.iter_mut()
395-
.filter_map(|(user_msg, asst_msg)| {
396-
if let (Some(tool_use_results), AssistantMessage::ToolUse { ref mut tool_uses, .. }) =
397-
(user_msg.tool_use_results(), asst_msg)
398-
{
399-
Some((tool_use_results, tool_uses))
400-
} else {
401-
None
402-
}
364+
pub fn enforce_tool_use_history_invariants(&mut self) {
365+
let tool_names: HashSet<_> = self
366+
.tools
367+
.values()
368+
.flat_map(|tools| {
369+
tools.iter().map(|tool| match tool {
370+
Tool::ToolSpecification(tool_specification) => tool_specification.name.as_str(),
403371
})
404-
.fold(
405-
(tool_use_results, Vec::<&mut AssistantToolUse>::new()),
406-
|(mut tool_use_results, mut tool_uses), (results, uses)| {
407-
let mut results = results.iter().collect::<Vec<_>>();
408-
let mut uses = uses.iter_mut().collect::<Vec<_>>();
409-
tool_use_results.append(&mut results);
410-
tool_uses.append(&mut uses);
411-
(tool_use_results, tool_uses)
412-
},
413-
)
414-
};
372+
})
373+
.filter(|name| *name != DUMMY_TOOL_NAME)
374+
.collect();
375+
376+
for (_, assistant) in &mut self.history {
377+
if let AssistantMessage::ToolUse { ref mut tool_uses, .. } = assistant {
378+
for tool_use in tool_uses {
379+
if tool_names.contains(tool_use.name.as_str()) {
380+
continue;
381+
}
415382

416-
// Replace tool uses associated with tools that does not exist / no longer exists with
417-
// dummy (i.e. put them to sleep / dormant)
418-
for result in tool_use_results {
419-
let tool_use_id = result.tool_use_id.as_str();
420-
let corresponding_tool_use = tool_uses.iter_mut().find(|tool_use| tool_use_id == tool_use.id);
421-
if let Some(tool_use) = corresponding_tool_use {
422-
if tool_name_list.contains(&tool_use.name.as_str()) {
423-
// If this tool matches of the tools in our list, this is not our
424-
// concern, error or not.
425-
continue;
426-
}
427-
if let ToolResultStatus::Error = result.status {
428-
// case 2 and 3
429-
tool_use.name = DUMMY_TOOL_NAME.to_string();
430-
tool_use.args = serde_json::json!({});
431-
} else {
432-
// case 1
433-
let full_name = tool_name_list.iter().find(|name| name.ends_with(&tool_use.name));
434-
// We should be able to find a match but if not we'll just treat it as
435-
// a dummy and move on
436-
if let Some(full_name) = full_name {
437-
tool_use.name = (*full_name).to_string();
438-
} else {
439-
tool_use.name = DUMMY_TOOL_NAME.to_string();
440-
tool_use.args = serde_json::json!({});
383+
if tool_names.contains(tool_use.orig_name.as_str()) {
384+
tool_use.name = tool_use.orig_name.clone();
385+
tool_use.args = tool_use.orig_args.clone();
386+
continue;
441387
}
442-
}
443-
}
444-
}
445388

446-
// Revive tools that were previously dormant if they now corresponds to one of the tools in
447-
// our list of available tools. Note that this check only works because tn_map does NOT
448-
// contain names of native tools.
449-
for tool_use in tool_uses {
450-
if tool_use.name == DUMMY_TOOL_NAME
451-
&& tool_use
452-
.orig_name
453-
.as_ref()
454-
.is_some_and(|name| tool_name_list.contains(&(*name).as_str()))
455-
{
456-
tool_use.name = tool_use
457-
.orig_name
458-
.as_ref()
459-
.map_or(DUMMY_TOOL_NAME.to_string(), |name| name.clone());
460-
tool_use.args = tool_use
461-
.orig_args
462-
.as_ref()
463-
.map_or(serde_json::json!({}), |args| args.clone());
389+
let names: Vec<&str> = tool_names
390+
.iter()
391+
.filter_map(|name| {
392+
if name.ends_with(&tool_use.name) {
393+
Some(*name)
394+
} else {
395+
None
396+
}
397+
})
398+
.collect();
399+
400+
// There's only one tool use matching, so we can just replace it with the
401+
// found name.
402+
if names.len() == 1 {
403+
tool_use.name = (*names.first().unwrap()).to_string();
404+
continue;
405+
}
406+
407+
// Otherwise, we have to replace it with a dummy.
408+
tool_use.name = DUMMY_TOOL_NAME.to_string();
409+
}
464410
}
465411
}
466412
}
@@ -514,8 +460,8 @@ impl ConversationState {
514460
.expect("unable to construct conversation state")
515461
}
516462

517-
pub async fn update_state(&mut self) {
518-
let needs_update = self.tool_manager.has_new_stuff.load(Ordering::Acquire);
463+
pub async fn update_state(&mut self, force_update: bool) {
464+
let needs_update = self.tool_manager.has_new_stuff.load(Ordering::Acquire) || force_update;
519465
if !needs_update {
520466
return;
521467
}
@@ -540,12 +486,13 @@ impl ConversationState {
540486
// We call this in [Self::enforce_conversation_invariants] as well. But we need to call it
541487
// here as well because when it's being called in [Self::enforce_conversation_invariants]
542488
// it is only checking the last entry.
543-
self.enforce_tool_use_history_invariants(false);
489+
self.enforce_tool_use_history_invariants();
544490
}
545491

546492
/// Returns a conversation state representation which reflects the exact conversation to send
547493
/// back to the model.
548494
pub async fn backend_conversation_state(&mut self, run_hooks: bool, quiet: bool) -> BackendConversationState<'_> {
495+
self.update_state(false).await;
549496
self.enforce_conversation_invariants();
550497

551498
// Run hooks and add to conversation start and next user message.

crates/chat-cli/src/cli/chat/message.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,11 +349,11 @@ pub struct AssistantToolUse {
349349
/// The name for the tool as exposed to the model
350350
pub name: String,
351351
/// Original name of the tool
352-
pub orig_name: Option<String>,
352+
pub orig_name: String,
353353
/// The input to pass to the tool as exposed to the model
354354
pub args: serde_json::Value,
355355
/// Original input passed to the tool
356-
pub orig_args: Option<serde_json::Value>,
356+
pub orig_args: serde_json::Value,
357357
}
358358

359359
impl From<AssistantToolUse> for ToolUse {

0 commit comments

Comments
 (0)