Skip to content
37 changes: 24 additions & 13 deletions crates/chat-cli/src/cli/chat/conversation_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ use crate::api_client::model::{
UserInputMessageContext,
};
use crate::cli::chat::util::shared_writer::SharedWriter;
use crate::database::Database;
use crate::mcp_client::Prompt;
use crate::platform::Context;

Expand Down Expand Up @@ -222,12 +223,16 @@ impl ConversationState {
}

/// Sets the response message according to the currently set [Self::next_message].
pub fn push_assistant_message(&mut self, message: AssistantMessage) {
pub fn push_assistant_message(&mut self, message: AssistantMessage, database: &mut Database) {
debug_assert!(self.next_message.is_some(), "next_message should exist");
let next_user_message = self.next_message.take().expect("next user message should exist");

self.append_assistant_transcript(&message);
self.history.push_back((next_user_message, message));

if let Ok(cwd) = std::env::current_dir() {
database.set_conversation_by_path(cwd, self).ok();
}
}

/// Returns the conversation id.
Expand Down Expand Up @@ -951,7 +956,8 @@ mod tests {
for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) {
let s = conversation_state.as_sendable_conversation_state(true).await;
assert_conversation_state_invariants(s, i);
conversation_state.push_assistant_message(AssistantMessage::new_response(None, i.to_string()));
conversation_state
.push_assistant_message(AssistantMessage::new_response(None, i.to_string()), &mut database);
conversation_state.set_next_user_message(i.to_string()).await;
}
}
Expand All @@ -977,13 +983,14 @@ mod tests {
let s = conversation_state.as_sendable_conversation_state(true).await;
assert_conversation_state_invariants(s, i);

conversation_state.push_assistant_message(AssistantMessage::new_tool_use(None, i.to_string(), vec![
AssistantToolUse {
conversation_state.push_assistant_message(
AssistantMessage::new_tool_use(None, i.to_string(), vec![AssistantToolUse {
id: "tool_id".to_string(),
name: "tool name".to_string(),
args: serde_json::Value::Null,
},
]));
}]),
&mut database,
);
conversation_state.add_tool_results(vec![ToolUseResult {
tool_use_id: "tool_id".to_string(),
content: vec![],
Expand All @@ -1005,20 +1012,22 @@ mod tests {
let s = conversation_state.as_sendable_conversation_state(true).await;
assert_conversation_state_invariants(s, i);
if i % 3 == 0 {
conversation_state.push_assistant_message(AssistantMessage::new_tool_use(None, i.to_string(), vec![
AssistantToolUse {
conversation_state.push_assistant_message(
AssistantMessage::new_tool_use(None, i.to_string(), vec![AssistantToolUse {
id: "tool_id".to_string(),
name: "tool name".to_string(),
args: serde_json::Value::Null,
},
]));
}]),
&mut database,
);
conversation_state.add_tool_results(vec![ToolUseResult {
tool_use_id: "tool_id".to_string(),
content: vec![],
status: ToolResultStatus::Success,
}]);
} else {
conversation_state.push_assistant_message(AssistantMessage::new_response(None, i.to_string()));
conversation_state
.push_assistant_message(AssistantMessage::new_response(None, i.to_string()), &mut database);
conversation_state.set_next_user_message(i.to_string()).await;
}
}
Expand Down Expand Up @@ -1066,7 +1075,8 @@ mod tests {

assert_conversation_state_invariants(s, i);

conversation_state.push_assistant_message(AssistantMessage::new_response(None, i.to_string()));
conversation_state
.push_assistant_message(AssistantMessage::new_response(None, i.to_string()), &mut database);
conversation_state.set_next_user_message(i.to_string()).await;
}
}
Expand Down Expand Up @@ -1134,7 +1144,8 @@ mod tests {
s.user_input_message.content
);

conversation_state.push_assistant_message(AssistantMessage::new_response(None, i.to_string()));
conversation_state
.push_assistant_message(AssistantMessage::new_response(None, i.to_string()), &mut database);
conversation_state.set_next_user_message(i.to_string()).await;
}
}
Expand Down
Loading
Loading