From 20f2b368a768351b358ead0ce492b7e03e0ff860 Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Tue, 17 Mar 2026 10:52:23 -0400 Subject: [PATCH 1/4] feat: add EvaluationPlugin for agent invocation evaluation and retry Introduce an EvaluationPlugin that hooks into agent invocations to evaluate outputs against expected results and automatically retries with improved system prompts on failure. - Add EvaluationPlugin class that wraps agent __call__ to intercept invocations, run evaluators, and retry with LLM-suggested prompt improvements when evaluations fail - Add improvement suggestion prompt template for generating better system prompts based on evaluation feedback - Add comprehensive test suite covering plugin initialization, wrapping, evaluation execution, retry logic, and edge cases - Update ruff config to ignore line-length in plugin prompt templates --- pyproject.toml | 1 + src/strands_evals/plugins/__init__.py | 3 + .../plugins/evaluation_plugin.py | 144 ++++++ .../plugins/prompt_templates/__init__.py | 0 .../improvement_suggestion.py | 37 ++ tests/strands_evals/plugins/__init__.py | 0 .../plugins/test_evaluation_plugin.py | 466 ++++++++++++++++++ 7 files changed, 651 insertions(+) create mode 100644 src/strands_evals/plugins/__init__.py create mode 100644 src/strands_evals/plugins/evaluation_plugin.py create mode 100644 src/strands_evals/plugins/prompt_templates/__init__.py create mode 100644 src/strands_evals/plugins/prompt_templates/improvement_suggestion.py create mode 100644 tests/strands_evals/plugins/__init__.py create mode 100644 tests/strands_evals/plugins/test_evaluation_plugin.py diff --git a/pyproject.toml b/pyproject.toml index 511b56ee..fdf763d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -128,6 +128,7 @@ select = [ [tool.ruff.lint.per-file-ignores] "src/strands_evals/evaluators/prompt_templates/*" = ["E501"] # line-length "src/strands_evals/generators/prompt_template/*" = ["E501"] # line-length +"src/strands_evals/plugins/prompt_templates/*" = ["E501"] # line-length [tool.mypy] # Disable strict checks that cause false positives with Generic classes diff --git a/src/strands_evals/plugins/__init__.py b/src/strands_evals/plugins/__init__.py new file mode 100644 index 00000000..ca39e350 --- /dev/null +++ b/src/strands_evals/plugins/__init__.py @@ -0,0 +1,3 @@ +from .evaluation_plugin import EvaluationPlugin + +__all__ = ["EvaluationPlugin"] diff --git a/src/strands_evals/plugins/evaluation_plugin.py b/src/strands_evals/plugins/evaluation_plugin.py new file mode 100644 index 00000000..b23034fe --- /dev/null +++ b/src/strands_evals/plugins/evaluation_plugin.py @@ -0,0 +1,144 @@ +"""Plugin that evaluates agent invocations and retries with improvements on failure.""" + +import logging +from typing import Any, Union, cast + +from pydantic import BaseModel +from strands import Agent +from strands.models import Model +from strands.plugins.plugin import Plugin + +from strands_evals.evaluators.evaluator import Evaluator +from strands_evals.plugins.prompt_templates.improvement_suggestion import ( + IMPROVEMENT_SYSTEM_PROMPT, + compose_improvement_prompt, +) +from strands_evals.types.evaluation import EvaluationData, EvaluationOutput + +logger = logging.getLogger(__name__) + + +class ImprovementSuggestion(BaseModel): + """Structured output from the improvement suggestion LLM.""" + + reasoning: str + system_prompt: str + + +class EvaluationPlugin(Plugin): + """Evaluates agent output after each invocation and retries with improved system prompts on failure.""" + + @property + def name(self) -> str: + return "strands-evals" + + def __init__( + self, + evaluators: list[Evaluator], + max_retries: int = 1, + expected_output: Any = None, + expected_trajectory: list[Any] | None = None, + model: Union[Model, str, None] = None, + ): + self._evaluators = evaluators + self._max_retries = max_retries + self._expected_output = expected_output + self._expected_trajectory = expected_trajectory + self._model = model + self._agent: Any = None + + def init_agent(self, agent: Any) -> None: + self._agent = agent + original_call = agent.__class__.__call__ + plugin = self + + def wrapped_call(self_agent: Any, prompt: Any = None, **kwargs: Any) -> Any: + return plugin._invoke_with_evaluation(self_agent, original_call, prompt, **kwargs) + + wrapped_class = type( + agent.__class__.__name__, + (agent.__class__,), + {"__call__": wrapped_call}, + ) + agent.__class__ = wrapped_class + + def _invoke_with_evaluation(self, agent: Any, original_call: Any, prompt: Any, **kwargs: Any) -> Any: + original_system_prompt = agent.system_prompt + original_messages = list(agent.messages) + invocation_state = kwargs.get("invocation_state") or {} + + for attempt in range(1 + self._max_retries): + if attempt > 0: + agent.messages = list(original_messages) + + result = original_call(agent, prompt, **kwargs) + + evaluation_data = self._build_evaluation_data(prompt, result, invocation_state) + outputs = self._run_evaluators(evaluation_data) + all_pass = all(o.test_pass for o in outputs) + + if all_pass or attempt == self._max_retries: + break + + logger.debug( + "attempt=<%s>, evaluation_pass=<%s> | evaluation failed, generating improvements", + attempt + 1, + all_pass, + ) + print("begin retrying============================") + print('result') + print(result) + print('outputs') + print(outputs) + + expected_output = evaluation_data.expected_output + suggestion = self._suggest_improvements(prompt, str(result), outputs, agent.system_prompt, expected_output) + print('suggestion') + print(suggestion) + agent.system_prompt = suggestion + + agent.system_prompt = original_system_prompt + return result + + def _build_evaluation_data(self, prompt: Any, result: Any, invocation_state: dict) -> EvaluationData: + expected_output = invocation_state.get("expected_output", self._expected_output) + expected_trajectory = invocation_state.get("expected_trajectory", self._expected_trajectory) + + return EvaluationData( + input=prompt, + actual_output=str(result), + expected_output=expected_output, + expected_trajectory=expected_trajectory, + ) + + def _suggest_improvements( + self, + prompt: Any, + actual_output: str, + outputs: list[EvaluationOutput], + current_system_prompt: str | None, + expected_output: Any = None, + ) -> str: + failure_reasons = [o.reason for o in outputs if not o.test_pass and o.reason] + improvement_prompt = compose_improvement_prompt( + user_prompt=str(prompt), + actual_output=actual_output, + failure_reasons=failure_reasons, + current_system_prompt=current_system_prompt, + expected_output=str(expected_output) if expected_output is not None else None, + ) + suggestion_agent = Agent(model=self._model, system_prompt=IMPROVEMENT_SYSTEM_PROMPT, callback_handler=None) + result = suggestion_agent(improvement_prompt, structured_output_model=ImprovementSuggestion) + suggestion = cast(ImprovementSuggestion, result.structured_output) + return suggestion.system_prompt + + def _run_evaluators(self, evaluation_data: EvaluationData) -> list[EvaluationOutput]: + all_outputs: list[EvaluationOutput] = [] + for evaluator in self._evaluators: + try: + outputs = evaluator.evaluate(evaluation_data) + all_outputs.extend(outputs) + except Exception: + logger.exception("evaluator=<%s> | evaluator raised an exception", type(evaluator).__name__) + all_outputs.append(EvaluationOutput(score=0.0, test_pass=False, reason="evaluator raised an exception")) + return all_outputs diff --git a/src/strands_evals/plugins/prompt_templates/__init__.py b/src/strands_evals/plugins/prompt_templates/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/strands_evals/plugins/prompt_templates/improvement_suggestion.py b/src/strands_evals/plugins/prompt_templates/improvement_suggestion.py new file mode 100644 index 00000000..ac05a22c --- /dev/null +++ b/src/strands_evals/plugins/prompt_templates/improvement_suggestion.py @@ -0,0 +1,37 @@ +IMPROVEMENT_SYSTEM_PROMPT = """You are an expert at analyzing AI agent failures and suggesting system prompt improvements. + +Given an agent's current system prompt, the user's request, the agent's output, and evaluation failures, suggest a modified system prompt that addresses the failures while preserving the agent's core capabilities. + +Focus on: +- Adding specific instructions that address the evaluation failure reasons +- Preserving existing useful instructions from the current system prompt +- Being concise and actionable +- Not changing the fundamental purpose of the agent + +Return the complete improved system prompt that should replace the current one.""" + + +def compose_improvement_prompt( + user_prompt: str, + actual_output: str, + failure_reasons: list[str], + current_system_prompt: str | None, + expected_output: str | None = None, +) -> str: + parts = [] + + if current_system_prompt: + parts.append(f"{current_system_prompt}") + else: + parts.append("No system prompt set") + + parts.append(f"{user_prompt}") + parts.append(f"{actual_output}") + + if expected_output is not None: + parts.append(f"{expected_output}") + + reasons_text = "\n".join(f"- {reason}" for reason in failure_reasons) + parts.append(f"\n{reasons_text}\n") + + return "\n\n".join(parts) diff --git a/tests/strands_evals/plugins/__init__.py b/tests/strands_evals/plugins/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/strands_evals/plugins/test_evaluation_plugin.py b/tests/strands_evals/plugins/test_evaluation_plugin.py new file mode 100644 index 00000000..a8e9ab83 --- /dev/null +++ b/tests/strands_evals/plugins/test_evaluation_plugin.py @@ -0,0 +1,466 @@ +from unittest.mock import Mock, patch + +import pytest + +from strands_evals.plugins import EvaluationPlugin +from strands_evals.plugins.evaluation_plugin import ImprovementSuggestion +from strands_evals.types import EvaluationData, EvaluationOutput + + +@pytest.fixture +def mock_evaluator(): + evaluator = Mock() + return evaluator + + +class FakeAgent: + """Minimal agent class for testing __class__ swap without real Agent internals.""" + + def __init__(self): + self.system_prompt = "original prompt" + self.messages: list = [] + self._result = Mock() + self._result.__str__ = Mock(return_value="mock output") + self.call_count = 0 + + def __call__(self, prompt=None, **kwargs): + self.call_count += 1 + return self._result + + +@pytest.fixture +def mock_agent(): + return FakeAgent() + + +def test_plugin_name(mock_evaluator): + plugin = EvaluationPlugin(evaluators=[mock_evaluator]) + assert plugin.name == "strands-evals" + + +def test_init_with_defaults(mock_evaluator): + plugin = EvaluationPlugin(evaluators=[mock_evaluator]) + assert plugin._evaluators == [mock_evaluator] + assert plugin._max_retries == 1 + assert plugin._expected_output is None + assert plugin._expected_trajectory is None + assert plugin._model is None + + +def test_init_with_custom_values(mock_evaluator): + plugin = EvaluationPlugin( + evaluators=[mock_evaluator], + max_retries=3, + expected_output="expected", + expected_trajectory=["step1", "step2"], + model="us.anthropic.claude-sonnet-4-20250514-v1:0", + ) + assert plugin._evaluators == [mock_evaluator] + assert plugin._max_retries == 3 + assert plugin._expected_output == "expected" + assert plugin._expected_trajectory == ["step1", "step2"] + assert plugin._model == "us.anthropic.claude-sonnet-4-20250514-v1:0" + + +# --- Step 2: init_agent + __class__ swap --- + + +def test_init_agent_wraps_call(mock_evaluator, mock_agent): + """After init_agent, agent.__class__ should be a subclass of the original.""" + original_class = mock_agent.__class__ + plugin = EvaluationPlugin(evaluators=[mock_evaluator]) + plugin.init_agent(mock_agent) + assert mock_agent.__class__ is not original_class + assert issubclass(mock_agent.__class__, original_class) + + +def test_init_agent_preserves_isinstance(mock_evaluator, mock_agent): + """isinstance(agent, OriginalClass) should still be True after wrapping.""" + plugin = EvaluationPlugin(evaluators=[mock_evaluator]) + plugin.init_agent(mock_agent) + assert isinstance(mock_agent, FakeAgent) + + +def test_wrapped_call_invokes_original(mock_evaluator, mock_agent): + """Calling agent(prompt) after wrapping should still invoke the original agent logic.""" + mock_evaluator.evaluate.return_value = [Mock(test_pass=True)] + plugin = EvaluationPlugin(evaluators=[mock_evaluator], max_retries=0) + plugin.init_agent(mock_agent) + + result = mock_agent("test prompt") + + assert result is not None + + +# --- Step 3: Evaluation execution on invocation --- + + +def test_evaluators_run_after_invocation(mock_agent): + """Evaluators should be called with EvaluationData after invocation.""" + evaluator = Mock() + evaluator.evaluate.return_value = [EvaluationOutput(score=1.0, test_pass=True, reason="good")] + plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=0) + plugin.init_agent(mock_agent) + + mock_agent("What is 2+2?") + + evaluator.evaluate.assert_called_once() + eval_data = evaluator.evaluate.call_args[0][0] + assert isinstance(eval_data, EvaluationData) + assert eval_data.input == "What is 2+2?" + assert eval_data.actual_output == str(mock_agent._result) + + +def test_evaluation_data_uses_constructor_expected(mock_agent): + """EvaluationData should use expected values from constructor.""" + evaluator = Mock() + evaluator.evaluate.return_value = [EvaluationOutput(score=1.0, test_pass=True)] + plugin = EvaluationPlugin( + evaluators=[evaluator], + max_retries=0, + expected_output="4", + expected_trajectory=["calculator"], + ) + plugin.init_agent(mock_agent) + + mock_agent("What is 2+2?") + + eval_data = evaluator.evaluate.call_args[0][0] + assert eval_data.expected_output == "4" + assert eval_data.expected_trajectory == ["calculator"] + + +def test_evaluation_data_uses_invocation_state_expected(mock_agent): + """invocation_state expected values should override constructor values.""" + evaluator = Mock() + evaluator.evaluate.return_value = [EvaluationOutput(score=1.0, test_pass=True)] + plugin = EvaluationPlugin( + evaluators=[evaluator], + max_retries=0, + expected_output="constructor_value", + ) + plugin.init_agent(mock_agent) + + mock_agent("prompt", invocation_state={"expected_output": "state_value"}) + + eval_data = evaluator.evaluate.call_args[0][0] + assert eval_data.expected_output == "state_value" + + +def test_passing_evaluation_returns_immediately(mock_agent): + """When evaluations pass, result should be returned without retry.""" + evaluator = Mock() + evaluator.evaluate.return_value = [EvaluationOutput(score=1.0, test_pass=True)] + plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=3) + plugin.init_agent(mock_agent) + + result = mock_agent("prompt") + + assert result is mock_agent._result + evaluator.evaluate.assert_called_once() + + +# --- Step 4: Retry on failure --- + + +def test_retry_on_failure(mock_agent): + """Agent should be re-invoked when evaluation fails.""" + evaluator = Mock() + evaluator.evaluate.side_effect = [ + [EvaluationOutput(score=0.0, test_pass=False, reason="bad")], + [EvaluationOutput(score=1.0, test_pass=True, reason="good")], + ] + plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1) + plugin._suggest_improvements = Mock(return_value="improved prompt") + plugin.init_agent(mock_agent) + + mock_agent("prompt") + + assert mock_agent.call_count == 2 + assert evaluator.evaluate.call_count == 2 + + +def test_max_retries_respected(mock_agent): + """Should stop retrying after max_retries attempts.""" + evaluator = Mock() + evaluator.evaluate.return_value = [EvaluationOutput(score=0.0, test_pass=False, reason="always bad")] + plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=2) + plugin._suggest_improvements = Mock(return_value="improved prompt") + plugin.init_agent(mock_agent) + + mock_agent("prompt") + + # 1 initial + 2 retries = 3 total calls + assert mock_agent.call_count == 3 + assert evaluator.evaluate.call_count == 3 + + +def test_max_retries_zero_no_retry(mock_agent): + """No retry when max_retries=0.""" + evaluator = Mock() + evaluator.evaluate.return_value = [EvaluationOutput(score=0.0, test_pass=False)] + plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=0) + plugin.init_agent(mock_agent) + + mock_agent("prompt") + + assert mock_agent.call_count == 1 + assert evaluator.evaluate.call_count == 1 + + +def test_messages_reset_between_retries(mock_agent): + """Agent messages should be restored to pre-invocation state before each retry.""" + messages_during_calls = [] + + original_messages = ["pre-existing"] + mock_agent.messages = list(original_messages) + + original_call = mock_agent.__class__.__call__ + + def tracking_call(self, prompt=None, **kwargs): + messages_during_calls.append(list(self.messages)) + self.messages.append(f"response-{self.call_count}") + return original_call(self, prompt, **kwargs) + + mock_agent.__class__.__call__ = tracking_call + + evaluator = Mock() + evaluator.evaluate.return_value = [EvaluationOutput(score=0.0, test_pass=False)] + plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1) + plugin._suggest_improvements = Mock(return_value="improved prompt") + plugin.init_agent(mock_agent) + + mock_agent("prompt") + + # Both calls should start with the original messages + assert messages_during_calls[0] == original_messages + assert messages_during_calls[1] == original_messages + + +def test_system_prompt_restored_after_all_attempts(mock_agent): + """Original system prompt should be restored after retries complete.""" + original_prompt = mock_agent.system_prompt + evaluator = Mock() + evaluator.evaluate.return_value = [EvaluationOutput(score=0.0, test_pass=False)] + plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1) + plugin._suggest_improvements = Mock(return_value="improved prompt") + plugin.init_agent(mock_agent) + + mock_agent("prompt") + + assert mock_agent.system_prompt == original_prompt + + +# --- Step 5: Improvement suggestion generation --- + + +@patch("strands_evals.plugins.evaluation_plugin.Agent") +def test_improvement_suggestion_called_on_failure(mock_agent_class, mock_agent): + """LLM should be called to generate suggestions when evaluation fails.""" + mock_suggestion_agent = Mock() + mock_suggestion_result = Mock() + mock_suggestion_result.structured_output = ImprovementSuggestion( + reasoning="Output was wrong", system_prompt="Be more careful" + ) + mock_suggestion_agent.return_value = mock_suggestion_result + mock_agent_class.return_value = mock_suggestion_agent + + evaluator = Mock() + evaluator.evaluate.side_effect = [ + [EvaluationOutput(score=0.0, test_pass=False, reason="incorrect")], + [EvaluationOutput(score=1.0, test_pass=True, reason="correct")], + ] + plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1, model="test-model") + plugin.init_agent(mock_agent) + + mock_agent("prompt") + + mock_agent_class.assert_called_once() + mock_suggestion_agent.assert_called_once() + + +@patch("strands_evals.plugins.evaluation_plugin.Agent") +def test_improvement_suggestion_prompt_contains_failures(mock_agent_class, mock_agent): + """Improvement prompt should include evaluation failure reasons.""" + mock_suggestion_agent = Mock() + mock_suggestion_result = Mock() + mock_suggestion_result.structured_output = ImprovementSuggestion(reasoning="Needs fix", system_prompt="improved") + mock_suggestion_agent.return_value = mock_suggestion_result + mock_agent_class.return_value = mock_suggestion_agent + + evaluator = Mock() + evaluator.evaluate.side_effect = [ + [EvaluationOutput(score=0.0, test_pass=False, reason="answer was factually wrong")], + [EvaluationOutput(score=1.0, test_pass=True)], + ] + plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1) + plugin.init_agent(mock_agent) + + mock_agent("prompt") + + prompt_arg = mock_suggestion_agent.call_args[0][0] + assert "answer was factually wrong" in prompt_arg + assert "original prompt" in prompt_arg + + +@patch("strands_evals.plugins.evaluation_plugin.Agent") +def test_improvement_suggestion_prompt_contains_expected_output(mock_agent_class, mock_agent): + """Improvement prompt should include expected_output when available.""" + mock_suggestion_agent = Mock() + mock_suggestion_result = Mock() + mock_suggestion_result.structured_output = ImprovementSuggestion(reasoning="fix", system_prompt="improved") + mock_suggestion_agent.return_value = mock_suggestion_result + mock_agent_class.return_value = mock_suggestion_agent + + evaluator = Mock() + evaluator.evaluate.side_effect = [ + [EvaluationOutput(score=0.0, test_pass=False, reason="does not match")], + [EvaluationOutput(score=1.0, test_pass=True)], + ] + plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1, expected_output="Paris") + plugin.init_agent(mock_agent) + + mock_agent("What is the capital of France?") + + prompt_arg = mock_suggestion_agent.call_args[0][0] + assert "Paris" in prompt_arg + assert "ExpectedOutput" in prompt_arg + + +@patch("strands_evals.plugins.evaluation_plugin.Agent") +def test_improvement_suggestion_prompt_omits_expected_output_when_none(mock_agent_class, mock_agent): + """Improvement prompt should not include ExpectedOutput tag when no expected_output is set.""" + mock_suggestion_agent = Mock() + mock_suggestion_result = Mock() + mock_suggestion_result.structured_output = ImprovementSuggestion(reasoning="fix", system_prompt="improved") + mock_suggestion_agent.return_value = mock_suggestion_result + mock_agent_class.return_value = mock_suggestion_agent + + evaluator = Mock() + evaluator.evaluate.side_effect = [ + [EvaluationOutput(score=0.0, test_pass=False, reason="bad")], + [EvaluationOutput(score=1.0, test_pass=True)], + ] + plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1) + plugin.init_agent(mock_agent) + + mock_agent("prompt") + + prompt_arg = mock_suggestion_agent.call_args[0][0] + assert "ExpectedOutput" not in prompt_arg + + +@patch("strands_evals.plugins.evaluation_plugin.Agent") +def test_improved_system_prompt_applied_before_retry(mock_agent_class, mock_agent): + """Agent system prompt should be updated with suggestion before retry.""" + mock_suggestion_agent = Mock() + mock_suggestion_result = Mock() + mock_suggestion_result.structured_output = ImprovementSuggestion( + reasoning="Needs specificity", system_prompt="Always answer with the city name only" + ) + mock_suggestion_agent.return_value = mock_suggestion_result + mock_agent_class.return_value = mock_suggestion_agent + + system_prompts_during_calls = [] + original_call = mock_agent.__class__.__call__ + + def tracking_call(self, prompt=None, **kwargs): + system_prompts_during_calls.append(self.system_prompt) + return original_call(self, prompt, **kwargs) + + mock_agent.__class__.__call__ = tracking_call + + evaluator = Mock() + evaluator.evaluate.side_effect = [ + [EvaluationOutput(score=0.0, test_pass=False, reason="bad")], + [EvaluationOutput(score=1.0, test_pass=True)], + ] + plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1) + plugin.init_agent(mock_agent) + + mock_agent("prompt") + + assert system_prompts_during_calls[0] == "original prompt" + assert system_prompts_during_calls[1] == "Always answer with the city name only" + + +# --- Step 7: Multiple evaluators + edge cases --- + + +def test_multiple_evaluators_all_must_pass(mock_agent): + """All evaluators must pass for the invocation to be considered successful.""" + evaluator1 = Mock() + evaluator1.evaluate.return_value = [EvaluationOutput(score=1.0, test_pass=True)] + evaluator2 = Mock() + evaluator2.evaluate.return_value = [EvaluationOutput(score=1.0, test_pass=True)] + plugin = EvaluationPlugin(evaluators=[evaluator1, evaluator2], max_retries=0) + plugin.init_agent(mock_agent) + + mock_agent("prompt") + + evaluator1.evaluate.assert_called_once() + evaluator2.evaluate.assert_called_once() + assert mock_agent.call_count == 1 + + +def test_partial_evaluator_failure_triggers_retry(mock_agent): + """If any evaluator fails, retry should be triggered.""" + evaluator1 = Mock() + evaluator1.evaluate.return_value = [EvaluationOutput(score=1.0, test_pass=True)] + evaluator2 = Mock() + evaluator2.evaluate.side_effect = [ + [EvaluationOutput(score=0.0, test_pass=False, reason="failed")], + [EvaluationOutput(score=1.0, test_pass=True)], + ] + plugin = EvaluationPlugin(evaluators=[evaluator1, evaluator2], max_retries=1) + plugin._suggest_improvements = Mock(return_value="improved") + plugin.init_agent(mock_agent) + + mock_agent("prompt") + + assert mock_agent.call_count == 2 + + +def test_evaluator_exception_recorded_as_failure(mock_agent): + """Evaluator exceptions should be caught and treated as failures.""" + evaluator = Mock() + evaluator.evaluate.side_effect = [ + RuntimeError("evaluator crashed"), + [EvaluationOutput(score=1.0, test_pass=True)], + ] + plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1) + plugin._suggest_improvements = Mock(return_value="improved") + plugin.init_agent(mock_agent) + + mock_agent("prompt") + + assert mock_agent.call_count == 2 + + +def test_returns_final_attempt_result(mock_agent): + """Should return the result from the final attempt.""" + results = [Mock(__str__=Mock(return_value="first")), Mock(__str__=Mock(return_value="second"))] + call_idx = [0] + + original_call = mock_agent.__class__.__call__ + + def multi_result_call(self, prompt=None, **kwargs): + idx = call_idx[0] + call_idx[0] += 1 + original_call(self, prompt, **kwargs) + return results[idx] + + mock_agent.__class__.__call__ = multi_result_call + + evaluator = Mock() + evaluator.evaluate.side_effect = [ + [EvaluationOutput(score=0.0, test_pass=False)], + [EvaluationOutput(score=1.0, test_pass=True)], + ] + plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1) + plugin._suggest_improvements = Mock(return_value="improved") + plugin.init_agent(mock_agent) + + result = mock_agent("prompt") + + assert result is results[1] From 2e8d9437b7416905112731cfc78ea15b32a89d73 Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Wed, 18 Mar 2026 10:22:34 -0400 Subject: [PATCH 2/4] rm print --- src/strands_evals/plugins/evaluation_plugin.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/strands_evals/plugins/evaluation_plugin.py b/src/strands_evals/plugins/evaluation_plugin.py index b23034fe..bcba5215 100644 --- a/src/strands_evals/plugins/evaluation_plugin.py +++ b/src/strands_evals/plugins/evaluation_plugin.py @@ -85,16 +85,9 @@ def _invoke_with_evaluation(self, agent: Any, original_call: Any, prompt: Any, * attempt + 1, all_pass, ) - print("begin retrying============================") - print('result') - print(result) - print('outputs') - print(outputs) expected_output = evaluation_data.expected_output suggestion = self._suggest_improvements(prompt, str(result), outputs, agent.system_prompt, expected_output) - print('suggestion') - print(suggestion) agent.system_prompt = suggestion agent.system_prompt = original_system_prompt From 9aaa441541d04d46db5a718dcf0cb2496f5f872d Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Wed, 18 Mar 2026 10:44:15 -0400 Subject: [PATCH 3/4] feat: add docstrings and return structured ImprovementSuggestion type - Add comprehensive docstrings to all methods in EvaluationPlugin - Change _suggest_improvements to return ImprovementSuggestion instead of str - Add debug logging with reasoning when applying improved system prompt - Replace Union[X, Y] with modern X | Y syntax - Use dict default in kwargs.get() instead of --- .../plugins/evaluation_plugin.py | 82 +++++++++++++++++-- .../plugins/test_evaluation_plugin.py | 28 +++++-- 2 files changed, 96 insertions(+), 14 deletions(-) diff --git a/src/strands_evals/plugins/evaluation_plugin.py b/src/strands_evals/plugins/evaluation_plugin.py index bcba5215..c59a8f99 100644 --- a/src/strands_evals/plugins/evaluation_plugin.py +++ b/src/strands_evals/plugins/evaluation_plugin.py @@ -1,7 +1,7 @@ """Plugin that evaluates agent invocations and retries with improvements on failure.""" import logging -from typing import Any, Union, cast +from typing import Any, cast from pydantic import BaseModel from strands import Agent @@ -38,8 +38,20 @@ def __init__( max_retries: int = 1, expected_output: Any = None, expected_trajectory: list[Any] | None = None, - model: Union[Model, str, None] = None, + model: Model | str | None = None, ): + """Initialize the evaluation plugin. + + Args: + evaluators: Evaluators to run against agent output after each invocation. + max_retries: Maximum number of retry attempts when evaluation fails. + expected_output: Default expected output for evaluation. Can be overridden per-invocation + via ``invocation_state``. + expected_trajectory: Default expected trajectory for evaluation. Can be overridden + per-invocation via ``invocation_state``. + model: Model used by the improvement suggestion agent. Accepts a Model instance, + a model ID string, or None to use the default. + """ self._evaluators = evaluators self._max_retries = max_retries self._expected_output = expected_output @@ -48,6 +60,14 @@ def __init__( self._agent: Any = None def init_agent(self, agent: Any) -> None: + """Wrap the agent's ``__call__`` to intercept invocations for evaluation and retry. + + Creates a dynamic subclass of the agent's class with a wrapped ``__call__`` that runs + evaluators after each invocation and retries with an improved system prompt on failure. + + Args: + agent: The agent instance whose invocations will be evaluated. + """ self._agent = agent original_call = agent.__class__.__call__ plugin = self @@ -63,9 +83,22 @@ def wrapped_call(self_agent: Any, prompt: Any = None, **kwargs: Any) -> Any: agent.__class__ = wrapped_class def _invoke_with_evaluation(self, agent: Any, original_call: Any, prompt: Any, **kwargs: Any) -> Any: + """Run the agent, evaluate output, and retry with an improved system prompt on failure. + + Restores the original system prompt and messages after all attempts complete. + + Args: + agent: The agent instance being invoked. + original_call: The unwrapped ``__call__`` method. + prompt: The user prompt passed to the agent. + **kwargs: Additional keyword arguments forwarded to the agent call. + + Returns: + The result from the last agent invocation attempt. + """ original_system_prompt = agent.system_prompt original_messages = list(agent.messages) - invocation_state = kwargs.get("invocation_state") or {} + invocation_state = kwargs.get("invocation_state", {}) for attempt in range(1 + self._max_retries): if attempt > 0: @@ -88,12 +121,25 @@ def _invoke_with_evaluation(self, agent: Any, original_call: Any, prompt: Any, * expected_output = evaluation_data.expected_output suggestion = self._suggest_improvements(prompt, str(result), outputs, agent.system_prompt, expected_output) - agent.system_prompt = suggestion + logger.debug( + "attempt=<%s>, reasoning=<%s> | applying improved system prompt", attempt + 1, suggestion.reasoning + ) + agent.system_prompt = suggestion.system_prompt agent.system_prompt = original_system_prompt return result def _build_evaluation_data(self, prompt: Any, result: Any, invocation_state: dict) -> EvaluationData: + """Assemble evaluation data from the invocation context. + + Args: + prompt: The user prompt. + result: The agent's output. + invocation_state: Per-invocation overrides for expected values. + + Returns: + An EvaluationData instance ready for evaluator consumption. + """ expected_output = invocation_state.get("expected_output", self._expected_output) expected_trajectory = invocation_state.get("expected_trajectory", self._expected_trajectory) @@ -111,7 +157,19 @@ def _suggest_improvements( outputs: list[EvaluationOutput], current_system_prompt: str | None, expected_output: Any = None, - ) -> str: + ) -> ImprovementSuggestion: + """Ask an LLM to suggest an improved system prompt based on evaluation failures. + + Args: + prompt: The original user prompt. + actual_output: The agent's output as a string. + outputs: Evaluation outputs from the failed attempt. + current_system_prompt: The agent's current system prompt. + expected_output: The expected output, if available. + + Returns: + An ImprovementSuggestion containing the reasoning and a revised system prompt. + """ failure_reasons = [o.reason for o in outputs if not o.test_pass and o.reason] improvement_prompt = compose_improvement_prompt( user_prompt=str(prompt), @@ -122,10 +180,20 @@ def _suggest_improvements( ) suggestion_agent = Agent(model=self._model, system_prompt=IMPROVEMENT_SYSTEM_PROMPT, callback_handler=None) result = suggestion_agent(improvement_prompt, structured_output_model=ImprovementSuggestion) - suggestion = cast(ImprovementSuggestion, result.structured_output) - return suggestion.system_prompt + return cast(ImprovementSuggestion, result.structured_output) def _run_evaluators(self, evaluation_data: EvaluationData) -> list[EvaluationOutput]: + """Run all evaluators against the given evaluation data. + + Exceptions raised by individual evaluators are caught, logged, and recorded as failures + so that a single broken evaluator does not prevent the others from running. + + Args: + evaluation_data: The data to evaluate. + + Returns: + A list of evaluation outputs from all evaluators. + """ all_outputs: list[EvaluationOutput] = [] for evaluator in self._evaluators: try: diff --git a/tests/strands_evals/plugins/test_evaluation_plugin.py b/tests/strands_evals/plugins/test_evaluation_plugin.py index a8e9ab83..9533fe65 100644 --- a/tests/strands_evals/plugins/test_evaluation_plugin.py +++ b/tests/strands_evals/plugins/test_evaluation_plugin.py @@ -171,7 +171,9 @@ def test_retry_on_failure(mock_agent): [EvaluationOutput(score=1.0, test_pass=True, reason="good")], ] plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1) - plugin._suggest_improvements = Mock(return_value="improved prompt") + plugin._suggest_improvements = Mock( + return_value=ImprovementSuggestion(reasoning="needs improvement", system_prompt="improved prompt") + ) plugin.init_agent(mock_agent) mock_agent("prompt") @@ -185,7 +187,9 @@ def test_max_retries_respected(mock_agent): evaluator = Mock() evaluator.evaluate.return_value = [EvaluationOutput(score=0.0, test_pass=False, reason="always bad")] plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=2) - plugin._suggest_improvements = Mock(return_value="improved prompt") + plugin._suggest_improvements = Mock( + return_value=ImprovementSuggestion(reasoning="needs improvement", system_prompt="improved prompt") + ) plugin.init_agent(mock_agent) mock_agent("prompt") @@ -227,7 +231,9 @@ def tracking_call(self, prompt=None, **kwargs): evaluator = Mock() evaluator.evaluate.return_value = [EvaluationOutput(score=0.0, test_pass=False)] plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1) - plugin._suggest_improvements = Mock(return_value="improved prompt") + plugin._suggest_improvements = Mock( + return_value=ImprovementSuggestion(reasoning="needs improvement", system_prompt="improved prompt") + ) plugin.init_agent(mock_agent) mock_agent("prompt") @@ -243,7 +249,9 @@ def test_system_prompt_restored_after_all_attempts(mock_agent): evaluator = Mock() evaluator.evaluate.return_value = [EvaluationOutput(score=0.0, test_pass=False)] plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1) - plugin._suggest_improvements = Mock(return_value="improved prompt") + plugin._suggest_improvements = Mock( + return_value=ImprovementSuggestion(reasoning="needs improvement", system_prompt="improved prompt") + ) plugin.init_agent(mock_agent) mock_agent("prompt") @@ -413,7 +421,9 @@ def test_partial_evaluator_failure_triggers_retry(mock_agent): [EvaluationOutput(score=1.0, test_pass=True)], ] plugin = EvaluationPlugin(evaluators=[evaluator1, evaluator2], max_retries=1) - plugin._suggest_improvements = Mock(return_value="improved") + plugin._suggest_improvements = Mock( + return_value=ImprovementSuggestion(reasoning="needs improvement", system_prompt="improved") + ) plugin.init_agent(mock_agent) mock_agent("prompt") @@ -429,7 +439,9 @@ def test_evaluator_exception_recorded_as_failure(mock_agent): [EvaluationOutput(score=1.0, test_pass=True)], ] plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1) - plugin._suggest_improvements = Mock(return_value="improved") + plugin._suggest_improvements = Mock( + return_value=ImprovementSuggestion(reasoning="needs improvement", system_prompt="improved") + ) plugin.init_agent(mock_agent) mock_agent("prompt") @@ -458,7 +470,9 @@ def multi_result_call(self, prompt=None, **kwargs): [EvaluationOutput(score=1.0, test_pass=True)], ] plugin = EvaluationPlugin(evaluators=[evaluator], max_retries=1) - plugin._suggest_improvements = Mock(return_value="improved") + plugin._suggest_improvements = Mock( + return_value=ImprovementSuggestion(reasoning="needs improvement", system_prompt="improved") + ) plugin.init_agent(mock_agent) result = mock_agent("prompt") From e2fa1239979ebd56b16bb1201c2276328c80040d Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Wed, 18 Mar 2026 11:28:48 -0400 Subject: [PATCH 4/4] feat: add plugins module and bump strands-agents to >=1.28.0 - Export new module from the package's public API - Bump minimum strands-agents dependency from 1.0.0 to 1.28.0 to support functionality required by the plugins module --- pyproject.toml | 2 +- src/strands_evals/__init__.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fdf763d1..d557f65a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ authors = [ dependencies = [ "pydantic>=2.0.0,<3.0.0", "rich>=14.0.0,<15.0.0", - "strands-agents>=1.0.0", + "strands-agents>=1.28.0", "strands-agents-tools>=0.1.0,<1.0.0", "typing-extensions>=4.0", "opentelemetry-api>=1.20.0", diff --git a/src/strands_evals/__init__.py b/src/strands_evals/__init__.py index f5c600ce..37994ae4 100644 --- a/src/strands_evals/__init__.py +++ b/src/strands_evals/__init__.py @@ -1,4 +1,4 @@ -from . import evaluators, extractors, generators, providers, simulation, telemetry, types +from . import evaluators, extractors, generators, plugins, providers, simulation, telemetry, types from .case import Case from .experiment import Experiment from .simulation import ActorSimulator, UserSimulator @@ -12,6 +12,7 @@ "providers", "types", "generators", + "plugins", "simulation", "telemetry", "StrandsEvalsTelemetry",