Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 56 additions & 6 deletions src/strands_evals/simulation/tool_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from strands.models.model import Model
from strands.tools.decorator import DecoratedFunctionTool, FunctionToolMetadata

from strands_evals.types.simulation.hook_events import PostCallHookEvent, PreCallHookEvent
from strands_evals.types.simulation.tool import DefaultToolResponse, RegisteredTool

from .prompt_templates.tool_response_generation import TOOL_RESPONSE_PROMPT_TEMPLATE
Expand Down Expand Up @@ -166,6 +167,8 @@ def __init__(
state_registry: StateRegistry | None = None,
model: Model | str | None = None,
max_tool_call_cache_size: int = 20,
pre_call_hook: Callable | None = None,
post_call_hook: Callable | None = None,
):
"""
Initialize a ToolSimulator instance.
Expand All @@ -178,10 +181,21 @@ def __init__(
Only used when creating a new StateRegistry (ignored if state_registry
is provided). Older calls are automatically evicted when limit is exceeded.
Default is 20.
pre_call_hook: Optional callable invoked before the LLM generates a tool response.
Receives a PreCallHookEvent with tool_name, parameters, state_key,
and previous_calls. If it returns a non-None dict, that dict is used
as the tool response (short-circuiting the LLM call) and cached in
the state registry. If it returns None, normal LLM simulation proceeds.
post_call_hook: Optional callable invoked after the LLM generates a tool response
but before it is cached. Receives a PostCallHookEvent with tool_name,
parameters, state_key, and response. Must return a (possibly modified)
response dict.
"""
self.model = model
self.state_registry = state_registry or StateRegistry(max_tool_call_cache_size=max_tool_call_cache_size)
self._registered_tools: dict[str, RegisteredTool] = {}
self._pre_call_hook = pre_call_hook
self._post_call_hook = post_call_hook

def _create_tool_wrapper(self, registered_tool: RegisteredTool):
"""
Expand Down Expand Up @@ -245,7 +259,35 @@ def _parse_simulated_response(self, result: AgentResult) -> dict[str, Any]:
return response_data

def _call_tool(self, registered_tool: RegisteredTool, parameters_string: str, state_key: str) -> dict[str, Any]:
"""Simulate a tool invocation and return the response."""
"""Simulate a tool invocation and return the response.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This implementation reminds me of decorator pattern, which also applies changes pre and post a decorated function. Wondering if it might be better to refactor this with decorator pattern?

    def _apply_pre_call_hook(self, func: Callable) -> Callable:
        """Internal decorator that applies pre_call_hook logic."""
        if self._pre_call_hook is None:
            return func
        
        @wraps(func)
        def wrapper(registered_tool: RegisteredTool, parameters_string: str, state_key: str) -> dict[str, Any]:
            parameters = json.loads(parameters_string)
            current_state = self.state_registry.get_state(state_key)
            
            # Invoke pre-call hook
            event = PreCallHookEvent(
                tool_name=registered_tool.name,
                parameters=parameters,
                state_key=state_key,
                previous_calls=current_state.get("previous_calls", []),
            )
            hook_response = self._pre_call_hook(event)
            
            # Short-circuit if hook returns a response
            if hook_response is not None:
                if not isinstance(hook_response, dict):
                    raise TypeError(f"pre_call_hook must return a dict or None, got {type(hook_response).__name__}")
                self.state_registry.cache_tool_call(
                    registered_tool.name, state_key, hook_response, parameters=parameters
                )
                return hook_response
            
            # Otherwise, proceed with normal execution
            return func(registered_tool, parameters_string, state_key)
        
        return wrapper

    def _apply_post_call_hook(self, func: Callable) -> Callable:
        """Internal decorator that applies post_call_hook logic."""
        if self._post_call_hook is None:
            return func
        
        @wraps(func)
        def wrapper(registered_tool: RegisteredTool, parameters_string: str, state_key: str) -> dict[str, Any]:
            # Execute the function to get response
            response_data = func(registered_tool, parameters_string, state_key)
            
            # Apply post-call hook to modify response
            parameters = json.loads(parameters_string)
            event = PostCallHookEvent(
                tool_name=registered_tool.name,
                parameters=parameters,
                state_key=state_key,
                response=response_data,
            )
            modified_response = self._post_call_hook(event)
            
            if not isinstance(modified_response, dict):
                raise TypeError(f"post_call_hook must return a dict, got {type(modified_response).__name__}")
            
            return modified_response
        
        return wrapper

......

    def _create_tool_wrapper
        ......
        def wrapper(*args, **kwargs):
                # Apply decorators in reverse order (post first, then pre)
                # This ensures execution order is: pre -> core logic -> post
                call_func = self._call_tool
                call_func = self._apply_post_call_hook(call_func)
                call_func = self._apply_pre_call_hook(call_func)
                ......
                return call_func

I think the pros are a) self._call_tool is not touched, b) more modularized, c) it seems to me that the pre/post hooks need access to many members of TS. Instead of relying on Event classes we might rather give them full access to self. Though I also worry this seems an overkill instead of just modifying _call_tool. @poshinchen thoughts?

Copy link
Author

@kaghatim kaghatim Mar 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah the reason I went with hooks over subclassing was specifically because _call_tool is private. I didn't want to couple to its internals or commit its signature as public API. The event classes keep the contract narrow so _call_tool can be refactored freely. But I'm open to a different approach if you have a preference. I would be happy to rework it.

Edit: Good point on the decorator pattern. I think the inline approach is simpler here, it avoids re-parsing parameters_string across decorators and keeps the flow in one place. Happy to wait for @poshinchen's take.


If a pre_call_hook is configured and returns a non-None dict, that dict is used
as the tool response (short-circuiting the LLM call). The response is still cached.

If a post_call_hook is configured, it receives the LLM-generated response before
caching and may modify it.
"""
parameters = json.loads(parameters_string)
current_state = self.state_registry.get_state(state_key)

# Pre-call hook: may short-circuit the LLM call
if self._pre_call_hook is not None:
event = PreCallHookEvent(
tool_name=registered_tool.name,
parameters=parameters,
state_key=state_key,
previous_calls=current_state.get("previous_calls", []),
)
hook_response = self._pre_call_hook(event)
if hook_response is not None:
if not isinstance(hook_response, dict):
raise TypeError(f"pre_call_hook must return a dict or None, got {type(hook_response).__name__}")
self.state_registry.cache_tool_call(
registered_tool.name, state_key, hook_response, parameters=parameters
)
return hook_response

# Normal LLM simulation
# Get input schema from Strands tool decorator
input_schema_dict = registered_tool.function.tool_spec.get("inputSchema", {}).get("json", {})
input_schema = json.dumps(input_schema_dict, indent=2)
Expand All @@ -254,8 +296,6 @@ def _call_tool(self, registered_tool: RegisteredTool, parameters_string: str, st
output_schema = registered_tool.output_schema.model_json_schema()
output_schema_string = json.dumps(output_schema, indent=2)

current_state = self.state_registry.get_state(state_key)

prompt = TOOL_RESPONSE_PROMPT_TEMPLATE.format(
tool_name=registered_tool.name,
input_schema=input_schema,
Expand All @@ -268,9 +308,19 @@ def _call_tool(self, registered_tool: RegisteredTool, parameters_string: str, st

response_data = self._parse_simulated_response(result)

self.state_registry.cache_tool_call(
registered_tool.name, state_key, response_data, parameters=json.loads(parameters_string)
)
# Post-call hook: may modify the response before caching
if self._post_call_hook is not None:
event = PostCallHookEvent(
tool_name=registered_tool.name,
parameters=parameters,
state_key=state_key,
response=response_data,
)
response_data = self._post_call_hook(event)
if not isinstance(response_data, dict):
raise TypeError(f"post_call_hook must return a dict, got {type(response_data).__name__}")

self.state_registry.cache_tool_call(registered_tool.name, state_key, response_data, parameters=parameters)
return response_data

def tool(
Expand Down
38 changes: 38 additions & 0 deletions src/strands_evals/types/simulation/hook_events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from dataclasses import dataclass, field
from typing import Any


@dataclass
class PreCallHookEvent:
    """
    Payload handed to a pre_call_hook before the LLM simulates a tool response.

    A hook receiving this event may inspect the pending call and, based on the
    registered history, decide whether to short-circuit the simulation.

    Attributes:
        tool_name: Name of the tool about to be invoked.
        parameters: Parameters for the pending call, already parsed from JSON.
        state_key: Key identifying the state entry (tool_name or share_state_id).
        previous_calls: Prior tool-call records fetched from the state registry;
            defaults to an empty list when no history exists.
    """

    tool_name: str
    parameters: dict[str, Any]
    state_key: str
    previous_calls: list[dict[str, Any]] = field(default_factory=list)


@dataclass
class PostCallHookEvent:
    """
    Payload handed to a post_call_hook after the LLM has simulated a tool response.

    The hook sees the generated response before it is cached and may return a
    modified version of it.

    Attributes:
        tool_name: Name of the tool that was just called.
        parameters: Parameters of the completed call, already parsed from JSON.
        state_key: Key identifying the state entry (tool_name or share_state_id).
        response: The LLM-generated response dict, open to modification by the
            hook; defaults to an empty dict.
    """

    tool_name: str
    parameters: dict[str, Any]
    state_key: str
    response: dict[str, Any] = field(default_factory=dict)
Loading