Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .config/nextest.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Default nextest profile (used when no --profile flag is given):
# fail-fast is disabled here.
[profile.default]
fail-fast = false

# Profile intended for CI runs (selected with `--profile ci`):
# fail-fast is enabled here.
[profile.ci]
fail-fast = true
27 changes: 27 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Workflow that runs the workspace test suite with cargo-nextest.
name: Test

# Triggered by pushes to main and by pull requests targeting main.
on:
push:
branches: [main]
pull_request:
branches: [main]

jobs:
test:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4

# Installs the protobuf compiler; the repo token is passed to the action.
- name: Install protoc
uses: arduino/setup-protoc@v3
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}

# Adds the wasm32-wasip2 compilation target to the Rust toolchain.
- name: Add wasm32-wasip2 target
run: rustup target add wasm32-wasip2

- name: Install cargo-nextest
uses: taiki-e/install-action@nextest

# Runs every workspace test using the `ci` nextest profile.
- name: Run tests
run: cargo nextest run --workspace --profile ci
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

330 changes: 329 additions & 1 deletion crates/ein-agent/src/agents.rs
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ mod tests {
use std::sync::Mutex;

use async_trait::async_trait;
use ein_core::types::{Choice, CompletionResponse, ToolDef, ToolResult};
use ein_core::types::{Choice, CompletionResponse, ToolDef, ToolResult, Usage};

use super::*;

Expand Down Expand Up @@ -692,4 +692,332 @@ mod tests {

assert_eq!(*tool.called_arg.lock().unwrap(), "test_val".to_string());
}

// ---------------------------------------------------------------------------
// Shared test helpers
// ---------------------------------------------------------------------------

/// Builds a `Role::Tool` message whose body is `content` and whose
/// `tool_call_id` is `id`; `tool_calls` is left unset.
fn tool_msg(id: &str, content: impl Into<String>) -> Message {
    let body: String = content.into();
    let call_id = id.to_string();
    Message {
        role: Role::Tool,
        content: Some(body),
        tool_calls: None,
        tool_call_id: Some(call_id),
    }
}

/// Builds a `Role::User` message with the given text and no tool metadata.
fn user_msg(content: impl Into<String>) -> Message {
    let text: String = content.into();
    Message {
        role: Role::User,
        content: Some(text),
        tool_call_id: None,
        tool_calls: None,
    }
}

/// Builds a `Role::System` message with the given text and no tool metadata.
fn system_msg(content: impl Into<String>) -> Message {
    let text: String = content.into();
    Message {
        role: Role::System,
        content: Some(text),
        tool_call_id: None,
        tool_calls: None,
    }
}

/// Builds a `CompletionResponse` holding a single assistant choice that
/// finished with `FinishReason::Stop` and carries `content` as its text.
/// Usage and error are both left as `None`.
fn stop_response(content: &str) -> CompletionResponse {
    let message = Message {
        role: Role::Assistant,
        content: Some(content.to_string()),
        tool_calls: None,
        tool_call_id: None,
    };
    let choice = Choice {
        index: None,
        finish_reason: FinishReason::Stop,
        message,
    };

    CompletionResponse {
        choices: vec![choice],
        usage: None,
        error: None,
    }
}

// ---------------------------------------------------------------------------
// truncate_old_tool_results — tested directly (private method, same file)
// ---------------------------------------------------------------------------

/// Value the tests below pass to `max_tool_result_chars`; tool content
/// longer than this is expected to be truncated once it is stale.
const TEST_THRESHOLD: usize = 50;
/// Value the tests below pass to `num_recent_messages`; the trailing
/// messages covered by this count are expected to be left untouched.
const TEST_WINDOW: usize = 2;

/// Oversized tool results that sit before the recent-message window must be
/// replaced by a truncation notice, while the recent messages stay as-is.
#[test]
fn truncate_old_tool_results_replaces_large_stale_content() {
    let oversized = "x".repeat(TEST_THRESHOLD + 1);

    let mut agent = Agent::builder(basic_test_client())
        .num_recent_messages(TEST_WINDOW)
        .max_tool_result_chars(TEST_THRESHOLD)
        .with_message_history(vec![
            tool_msg("t1", &oversized),
            tool_msg("t2", &oversized),
            user_msg("recent 1"),
            user_msg("recent 2"),
        ])
        .build();

    agent.truncate_old_tool_results();

    let msgs = agent.messages();

    // The two leading tool results are outside the window and over the limit.
    for idx in 0..2 {
        let text = msgs[idx].content.as_deref().unwrap_or("");
        assert!(
            text.starts_with("[Tool result truncated:"),
            "old large tool result must be truncated"
        );
    }

    // The two trailing user messages are inside the window.
    let first_recent = msgs[2].content.as_deref();
    let second_recent = msgs[3].content.as_deref();
    assert_eq!(first_recent, Some("recent 1"));
    assert_eq!(second_recent, Some("recent 2"));
}

/// When every message falls inside the recent window, even oversized tool
/// results must be left alone.
#[test]
fn truncate_old_tool_results_keeps_recent_messages_intact() {
    let oversized = "x".repeat(TEST_THRESHOLD + 1);

    // Three messages with a window of three: nothing is old enough to truncate.
    let mut agent = Agent::builder(basic_test_client())
        .num_recent_messages(3)
        .max_tool_result_chars(TEST_THRESHOLD)
        .with_message_history(vec![
            tool_msg("t1", &oversized),
            tool_msg("t2", &oversized),
            tool_msg("t3", &oversized),
        ])
        .build();

    agent.truncate_old_tool_results();

    for msg in agent.messages() {
        let text = msg.content.as_deref().unwrap_or("");
        let was_truncated = text.starts_with("[Tool result truncated:");
        assert!(!was_truncated, "recent messages must not be truncated");
    }
}

/// Oversized User and System messages must never be truncated, even when
/// they are older than the recent-message window.
///
/// Two short trailing messages fill the window (`TEST_WINDOW`) so that both
/// oversized non-tool messages are genuinely stale; without them the System
/// message would sit inside the window and its assertion would pass no
/// matter how the role check behaves.
#[test]
fn truncate_old_tool_results_ignores_non_tool_messages() {
    let large = "x".repeat(TEST_THRESHOLD + 1);
    let history = vec![
        user_msg(&large),
        system_msg(&large),
        tool_msg("t1", "small"),
        // Padding that occupies the recent window.
        user_msg("recent 1"),
        user_msg("recent 2"),
    ];

    let mut agent = Agent::builder(basic_test_client())
        .num_recent_messages(TEST_WINDOW)
        .max_tool_result_chars(TEST_THRESHOLD)
        .with_message_history(history)
        .build();

    agent.truncate_old_tool_results();

    let msgs = agent.messages();
    assert_eq!(msgs[0].content.as_deref(), Some(large.as_str()), "User must not be truncated");
    assert_eq!(msgs[1].content.as_deref(), Some(large.as_str()), "System must not be truncated");
    // The stale tool result is below the threshold, so it must also survive.
    assert_eq!(
        msgs[2].content.as_deref(),
        Some("small"),
        "small stale tool result must not be truncated"
    );
}

/// A stale tool result whose length equals the threshold exactly must be
/// kept: truncation applies only to content strictly longer than the limit.
#[test]
fn truncate_old_tool_results_skips_content_at_threshold() {
    let boundary = "x".repeat(TEST_THRESHOLD);

    let mut agent = Agent::builder(basic_test_client())
        .num_recent_messages(TEST_WINDOW)
        .max_tool_result_chars(TEST_THRESHOLD)
        .with_message_history(vec![
            tool_msg("t1", &boundary),
            user_msg("recent 1"),
            user_msg("recent 2"),
        ])
        .build();

    agent.truncate_old_tool_results();

    let oldest = agent.messages()[0].content.as_deref();
    assert_eq!(
        oldest,
        Some(boundary.as_str()),
        "content exactly at threshold must not be truncated"
    );
}

// ---------------------------------------------------------------------------
// compact_history
// ---------------------------------------------------------------------------

/// With only a system message in the history, `compact_history` is expected
/// to succeed and return an empty string.
#[tokio::test]
async fn compact_history_returns_empty_when_no_user_messages() {
    let history = vec![system_msg("you are helpful")];
    let mut agent = Agent::builder(basic_test_client())
        .with_message_history(history)
        .build();

    let summary = agent.compact_history().await.unwrap();

    assert_eq!(summary, "", "nothing to compact without user messages");
}

/// After compaction the history must consist of the original system message
/// followed by a second system message that contains the model's summary.
#[tokio::test]
async fn compact_history_replaces_history_with_system_plus_summary() {
    let summary = "Goals discussed, files modified, current state.";
    let client = BasicTestModelClient {
        response: stop_response(summary),
    };
    let history = vec![system_msg("sys"), user_msg("do stuff")];
    let mut agent = Agent::builder(client)
        .with_message_history(history)
        .build();

    let returned = agent.compact_history().await.unwrap();
    assert_eq!(returned, summary);

    let msgs = agent.messages();
    assert_eq!(msgs.len(), 2, "original system + new summary system");

    // First entry: the untouched original system prompt.
    let original = &msgs[0];
    assert!(matches!(original.role, Role::System));
    assert_eq!(original.content.as_deref(), Some("sys"));

    // Second entry: a system message embedding the summary text.
    let compacted = &msgs[1];
    assert!(matches!(compacted.role, Role::System));
    let compacted_text = compacted.content.as_deref().unwrap_or("");
    assert!(compacted_text.contains(summary));
}

/// `compact_history` must emit exactly one `ContentDelta` event, and that
/// event must carry the summary text returned by the model.
#[tokio::test]
async fn compact_history_broadcasts_content_delta_event() {
    use std::sync::Arc;

    let summary = "Detailed summary.";

    // Shared sink that the event handler appends every received event to.
    let captured: Arc<Mutex<Vec<AgentEvent>>> = Arc::new(Mutex::new(Vec::new()));
    let sink = captured.clone();

    let client = BasicTestModelClient {
        response: stop_response(summary),
    };
    let mut agent = Agent::builder(client)
        .with_event_handler(move |event| {
            let sink = sink.clone();
            async move {
                sink.lock().unwrap().push(event);
            }
        })
        .with_message_history(vec![user_msg("do stuff")])
        .build();

    agent.compact_history().await.unwrap();

    let events = captured.lock().unwrap();
    let deltas: Vec<&str> = events
        .iter()
        .filter_map(|event| match event {
            AgentEvent::ContentDelta(text) => Some(text.as_str()),
            _ => None,
        })
        .collect();
    assert_eq!(deltas, vec![summary]);
}

// ---------------------------------------------------------------------------
// chat error paths
// ---------------------------------------------------------------------------

/// A response that carries an `error` payload and no choices must make
/// `chat` fail with `AgentError::ModelClient`, and the error text must
/// include the API's message.
#[tokio::test]
async fn chat_returns_error_on_api_error_response() {
    let response = CompletionResponse {
        choices: vec![],
        usage: None,
        error: Some(serde_json::json!({"message": "insufficient credits"})),
    };
    let mut agent = Agent::builder(BasicTestModelClient { response }).build();

    let err = agent.chat("prompt").await.unwrap_err();

    assert!(matches!(err, AgentError::ModelClient(_)));
    let rendered = err.to_string();
    assert!(rendered.contains("insufficient credits"));
}

/// A choice that finishes with `FinishReason::Unsupported` must make `chat`
/// fail with `AgentError::UnsupportedFinishReason`.
#[tokio::test]
async fn chat_returns_error_on_unsupported_finish_reason() {
    let message = Message {
        role: Role::Assistant,
        content: None,
        tool_calls: None,
        tool_call_id: None,
    };
    let choice = Choice {
        index: None,
        finish_reason: FinishReason::Unsupported,
        message,
    };
    let response = CompletionResponse {
        choices: vec![choice],
        usage: None,
        error: None,
    };
    let mut agent = Agent::builder(BasicTestModelClient { response }).build();

    let err = agent.chat("prompt").await.unwrap_err();

    assert!(matches!(err, AgentError::UnsupportedFinishReason(_)));
}

// ---------------------------------------------------------------------------
// Token usage events and clear
// ---------------------------------------------------------------------------

/// When the model response includes usage figures, `chat` must emit a
/// `TokenUsage` event carrying those same prompt/completion/total counts.
#[tokio::test]
async fn chat_emits_token_usage_event() {
    use std::sync::Arc;

    // Shared sink that the event handler appends every received event to.
    let captured: Arc<Mutex<Vec<AgentEvent>>> = Arc::new(Mutex::new(Vec::new()));
    let sink = captured.clone();

    let message = Message {
        role: Role::Assistant,
        content: Some("done".to_string()),
        tool_calls: None,
        tool_call_id: None,
    };
    let choice = Choice {
        index: None,
        finish_reason: FinishReason::Stop,
        message,
    };
    let usage = Usage {
        prompt_tokens: 10,
        completion_tokens: 5,
        total_tokens: 15,
    };
    let response = CompletionResponse {
        choices: vec![choice],
        usage: Some(usage),
        error: None,
    };

    let mut agent = Agent::builder(BasicTestModelClient { response })
        .with_event_handler(move |event| {
            let sink = sink.clone();
            async move {
                sink.lock().unwrap().push(event);
            }
        })
        .build();

    agent.chat("hello").await.unwrap();

    let events = captured.lock().unwrap();
    let usage = events.iter().find_map(|event| match event {
        AgentEvent::TokenUsage { prompt_tokens, completion_tokens, total_tokens } => {
            Some((*prompt_tokens, *completion_tokens, *total_tokens))
        }
        _ => None,
    });
    assert_eq!(usage, Some((10, 5, 15)), "TokenUsage event must carry correct totals");
}

/// `clear_messages` must leave the agent with an empty message history.
///
/// This is a plain synchronous test: nothing here is awaited, so no async
/// runtime is needed (the other synchronous tests in this module build
/// agents the same way under `#[test]`).
#[test]
fn clear_messages_empties_history() {
    let mut agent = Agent::builder(basic_test_client())
        .with_message_history(vec![system_msg("sys"), user_msg("hello")])
        .build();

    // Precondition: the seeded history is actually present before clearing.
    assert!(!agent.messages().is_empty());
    agent.clear_messages();
    assert!(agent.messages().is_empty());
}
}
Loading
Loading