feat: Add ReflectAndRetryPlugin to reflect on errors and retry when tool/model calls fail #1983

Open · wants to merge 1 commit into base: main
53 changes: 51 additions & 2 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -534,7 +534,13 @@ async def _call_llm_async(
with tracer.start_as_current_span('call_llm'):
if invocation_context.run_config.support_cfc:
invocation_context.live_request_queue = LiveRequestQueue()
async for llm_response in self.run_live(invocation_context):
responses_generator = self.run_live(invocation_context)
async for llm_response in self._run_and_handle_error(
responses_generator,
invocation_context,
llm_request,
model_response_event,
):
# Runs after_model_callback if it exists.
if altered_llm_response := await self._handle_after_model_callback(
invocation_context, llm_response, model_response_event
@@ -553,10 +559,16 @@ async def _call_llm_async(
# the counter beyond the max set value, then the execution is stopped
# right here, and exception is thrown.
invocation_context.increment_llm_call_count()
async for llm_response in llm.generate_content_async(
responses_generator = llm.generate_content_async(
llm_request,
stream=invocation_context.run_config.streaming_mode
== StreamingMode.SSE,
)
async for llm_response in self._run_and_handle_error(
responses_generator,
invocation_context,
llm_request,
model_response_event,
):
trace_call_llm(
invocation_context,
@@ -673,6 +685,43 @@ def _finalize_model_response_event(

return model_response_event

async def _run_and_handle_error(
self,
response_generator: AsyncGenerator[LlmResponse, None],
invocation_context: InvocationContext,
llm_request: LlmRequest,
model_response_event: Event,
) -> AsyncGenerator[LlmResponse, None]:
"""Runs the response generator and processes the error with plugins.
Args:
response_generator: The response generator to run.
invocation_context: The invocation context.
llm_request: The LLM request.
model_response_event: The model response event.
Yields:
A generator of LlmResponse.
"""
try:
async for response in response_generator:
yield response
except Exception as model_error:
callback_context = CallbackContext(
invocation_context, event_actions=model_response_event.actions
)
error_response = (
await invocation_context.plugin_manager.run_on_model_error_callback(
callback_context=callback_context,
llm_request=llm_request,
error=model_error,
)
)
if error_response is not None:
yield error_response
else:
raise model_error

def __get_llm(self, invocation_context: InvocationContext) -> BaseLlm:
from ...agents.llm_agent import LlmAgent

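For context, the new `_run_and_handle_error` wrapper only yields a recovery response when some plugin's `on_model_error_callback` returns one. A minimal sketch of such a plugin follows; the `FallbackOnModelErrorPlugin` name and fallback text are illustrative, and the import paths and `LlmResponse`/`types.Content` construction are assumptions based on the surrounding ADK code rather than part of this diff.

```python
from typing import Optional

from google.adk.agents.callback_context import CallbackContext
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.plugins.base_plugin import BasePlugin
from google.genai import types


class FallbackOnModelErrorPlugin(BasePlugin):
  """Illustrative plugin that converts model errors into a canned reply."""

  def __init__(self, name: str = "fallback_on_model_error"):
    super().__init__(name)

  async def on_model_error_callback(
      self,
      *,
      callback_context: CallbackContext,
      llm_request: LlmRequest,
      error: Exception,
  ) -> Optional[LlmResponse]:
    # Returning an LlmResponse makes _run_and_handle_error yield it instead
    # of re-raising; returning None would propagate the original error.
    return LlmResponse(
        content=types.Content(
            role="model",
            parts=[types.Part(text=f"The model call failed ({error}).")],
        )
    )
```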
18 changes: 15 additions & 3 deletions src/google/adk/flows/llm_flows/functions.py
@@ -176,9 +176,21 @@ async def handle_function_calls_async(

# Step 3: Otherwise, proceed calling the tool normally.
if function_response is None:
function_response = await __call_tool_async(
tool, args=function_args, tool_context=tool_context
)
try:
function_response = await __call_tool_async(
tool, args=function_args, tool_context=tool_context
)
except Exception as tool_error:
error_response = await invocation_context.plugin_manager.run_on_tool_error_callback(
tool=tool,
tool_args=function_args,
tool_context=tool_context,
error=tool_error,
)
if error_response is not None:
function_response = error_response
else:
raise tool_error

# Step 4: Check if plugin after_tool_callback overrides the function
# response.
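The tool-side path mirrors the model-side one: `handle_function_calls_async` only swallows the exception when a plugin's `on_tool_error_callback` returns a dict to use as the function response. A hedged sketch, with an illustrative class name and response shape:

```python
from typing import Any, Optional

from google.adk.plugins.base_plugin import BasePlugin
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext


class FallbackOnToolErrorPlugin(BasePlugin):
  """Illustrative plugin that converts tool exceptions into tool responses."""

  def __init__(self, name: str = "fallback_on_tool_error"):
    super().__init__(name)

  async def on_tool_error_callback(
      self,
      *,
      tool: BaseTool,
      tool_args: dict[str, Any],
      tool_context: ToolContext,
      error: Exception,
  ) -> Optional[dict]:
    # A non-None dict becomes the tool's function response; returning None
    # lets handle_function_calls_async re-raise the original error.
    return {
        "error": f"{type(error).__name__}: {error}",
        "tool": tool.name,
        "hint": "The tool call failed; consider adjusting the arguments.",
    }
```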
51 changes: 51 additions & 0 deletions src/google/adk/plugins/base_plugin.py
@@ -265,6 +265,31 @@ async def after_model_callback(
"""
pass

async def on_model_error_callback(
self,
*,
callback_context: CallbackContext,
llm_request: LlmRequest,
error: Exception,
) -> Optional[LlmResponse]:
"""Callback executed when a model call encounters an error.
This callback provides an opportunity to handle model errors gracefully,
potentially providing alternative responses or recovery mechanisms.
Args:
callback_context: The context for the current agent call.
llm_request: The request that was sent to the model when the error
occurred.
error: The exception that was raised during model execution.
Returns:
An optional LlmResponse. If an LlmResponse is returned, it will be used
instead of propagating the error. Returning `None` allows the original
error to be raised.
"""
pass

async def before_tool_callback(
self,
*,
@@ -315,3 +340,29 @@ async def after_tool_callback(
result.
"""
pass

async def on_tool_error_callback(
self,
*,
tool: BaseTool,
tool_args: dict[str, Any],
tool_context: ToolContext,
error: Exception,
) -> Optional[dict]:
"""Callback executed when a tool call encounters an error.
This callback provides an opportunity to handle tool errors gracefully,
potentially providing alternative responses or recovery mechanisms.
Args:
tool: The tool instance that encountered an error.
tool_args: The arguments that were passed to the tool.
tool_context: The context specific to the tool execution.
error: The exception that was raised during tool execution.
Returns:
An optional dictionary. If a dictionary is returned, it will be used as
the tool response instead of propagating the error. Returning `None`
allows the original error to be raised.
"""
pass
34 changes: 34 additions & 0 deletions src/google/adk/plugins/plugin_manager.py
@@ -48,6 +48,8 @@
"after_tool_callback",
"before_model_callback",
"after_model_callback",
"on_tool_error_callback",
"on_model_error_callback",
]

logger = logging.getLogger("google_adk." + __name__)
@@ -195,6 +197,21 @@ async def run_after_tool_callback(
result=result,
)

async def run_on_model_error_callback(
self,
*,
callback_context: CallbackContext,
llm_request: LlmRequest,
error: Exception,
) -> Optional[LlmResponse]:
"""Runs the `on_model_error_callback` for all plugins."""
return await self._run_callbacks(
"on_model_error_callback",
callback_context=callback_context,
llm_request=llm_request,
error=error,
)

async def run_before_model_callback(
self, *, callback_context: CallbackContext, llm_request: LlmRequest
) -> Optional[LlmResponse]:
@@ -215,6 +232,23 @@ async def run_after_model_callback(
llm_response=llm_response,
)

async def run_on_tool_error_callback(
self,
*,
tool: BaseTool,
tool_args: dict[str, Any],
tool_context: ToolContext,
error: Exception,
) -> Optional[dict]:
"""Runs the `on_tool_error_callback` for all plugins."""
return await self._run_callbacks(
"on_tool_error_callback",
tool=tool,
tool_args=tool_args,
tool_context=tool_context,
error=error,
)

async def _run_callbacks(
self, callback_name: PluginCallbackName, **kwargs: Any
) -> Optional[Any]:
171 changes: 171 additions & 0 deletions src/google/adk/plugins/reflect_retry_plugin.py
@@ -0,0 +1,171 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import json
from typing import Any
from typing import Literal
from typing import Optional

from pydantic import BaseModel

from ..tools.base_tool import BaseTool
from ..tools.tool_context import ToolContext
from .base_plugin import BasePlugin


class ReflectAndRetryPluginResponse(BaseModel):
"""Response from ReflectAndRetryPlugin."""

response_type: Literal["ERROR_HANDLED_BY_REFLECT_AND_RETRY_PLUGIN"] = "ERROR_HANDLED_BY_REFLECT_AND_RETRY_PLUGIN"
error_type: str = ""
error_details: str = ""
retry_count: int = 0
reflection_guidance: str = ""


class ReflectAndRetryPlugin(BasePlugin):
"""A plugin that provides error recovery through reflection and retry logic.

When tool calls fail with an exception, this plugin generates instructional
responses that encourage the model to reflect on the error and try a
different approach, rather than simply propagating the error.

This plugin is particularly useful for handling transient errors, API
limitations, or cases where the model might need to adjust its strategy
based on encountered obstacles.

Example:
>>> reflect_retry_plugin = ReflectAndRetryPlugin()
>>> runner = Runner(
... agents=[my_agent],
... plugins=[reflect_retry_plugin],
... )
"""

def __init__(self, name: str = "reflect_retry_plugin", max_retries: int = 3):
"""Initialize the reflect and retry plugin.

Args:
name: The name of the plugin instance.
max_retries: Maximum number of retries to attempt before giving up.
"""
super().__init__(name)
self.max_retries = max_retries
self._retry_counts: dict[str, int] = {}

async def on_tool_error_callback(
self,
*,
tool: BaseTool,
tool_args: dict[str, Any],
tool_context: ToolContext,
error: Exception,
) -> Optional[dict]:
"""Handle tool execution errors with reflection and retry logic."""
retry_key = self._get_retry_key(
tool_context.invocation_id, f"tool:{tool.name}"
)

if not self._should_retry(retry_key):
return self._get_tool_retry_exceed_msg(tool, error)

retry_count = self._increment_retry_count(retry_key)

# Create a reflective response instead of propagating the error
return self._create_tool_reflection_response(
tool, tool_args, error, retry_count
)

def _get_retry_key(self, context_id: str, operation: str) -> str:
"""Generate a unique key for tracking retries."""
return f"{context_id}:{operation}"

def _should_retry(self, retry_key: str) -> bool:
"""Check if we should attempt a retry for this operation."""
current_count = self._retry_counts.get(retry_key, 0)
return current_count < self.max_retries

def _increment_retry_count(self, retry_key: str) -> int:
"""Increment and return the retry count for an operation."""
self._retry_counts[retry_key] = self._retry_counts.get(retry_key, 0) + 1
return self._retry_counts[retry_key]

def _format_error_details(self, error: Exception) -> str:
"""Format error details for inclusion in reflection message."""
error_type = type(error).__name__
error_message = str(error)
return f"{error_type}: {error_message}"

def _create_tool_reflection_response(
self,
tool: BaseTool,
tool_args: dict[str, Any],
error: Exception,
retry_count: int,
) -> dict[str, Any]:
"""Create a reflection response for tool errors."""
args_summary = json.dumps(tool_args, indent=2, default=str)
error_details = self._format_error_details(error)

reflection_message = f"""
The tool call to '{tool.name}' failed with the following error:

Error: {error_details}

Tool Arguments Used:
{args_summary}

**Reflection Instructions:**
The current approach did not work. Think about the potential issues and explicitly try a different approach. Consider:

1. **Parameter Issues**: Are the arguments correctly formatted or within expected ranges?
2. **Alternative Methods**: Is there a different tool or approach that might work better?
3. **Error Context**: What does this specific error tell you about what went wrong?
4. **Incremental Steps**: Can you break down the task into smaller, more manageable steps?

This is retry attempt {retry_count} of {self.max_retries}. Please analyze the error and adjust your strategy accordingly.

Instead of repeating the same approach, explicitly state what you learned from this error and how you plan to modify your approach.
"""

return ReflectAndRetryPluginResponse(
error_type=type(error).__name__,
error_details=str(error),
retry_count=retry_count,
reflection_guidance=reflection_message.strip(),
).model_dump(mode="json")

def _get_tool_retry_exceed_msg(
self,
tool: BaseTool,
error: Exception,
) -> dict[str, Any]:
"""Create a reflection response for tool errors."""
reflection_message = f"""
The tool call to '{tool.name}' has failed {self.max_retries} times and has exceeded the maximum retry limit.

Last Error: {self._format_error_details(error)}

**Instructions:**
Do not attempt to use this tool ('{tool.name}') again for this task.
You must try a different approach, using a different tool or strategy to accomplish the goal.
"""
return ReflectAndRetryPluginResponse(
error_type=type(error).__name__,
error_details=str(error),
retry_count=self.max_retries,
reflection_guidance=reflection_message.strip(),
).model_dump(mode="json")