Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions haystack/components/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
_validate_tool_breakpoint_is_valid,
)
from haystack.core.pipeline.pipeline import Pipeline
from haystack.core.pipeline.utils import _deepcopy_with_exceptions
from haystack.core.pipeline.utils import _deepcopy_with_exceptions, warm_tools_on_component
from haystack.core.serialization import component_to_dict, default_from_dict, default_to_dict
from haystack.dataclasses import ChatMessage, ChatRole
from haystack.dataclasses.breakpoints import AgentBreakpoint, AgentSnapshot, PipelineSnapshot, ToolBreakpoint
Expand Down Expand Up @@ -401,7 +401,7 @@ def _runtime_checks(self, break_point: Optional[AgentBreakpoint], snapshot: Opti
"""
if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"):
raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run()'.")

warm_tools_on_component(self, "tools")
if break_point and snapshot:
raise ValueError(
"break_point and snapshot cannot be provided at the same time. The agent run will be aborted."
Expand Down
8 changes: 6 additions & 2 deletions haystack/core/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,10 +828,14 @@ def warm_up(self) -> None:
It's the node's responsibility to make sure this method can be called at every `Pipeline.run()`
without re-initializing everything.
"""
from haystack.core.pipeline.utils import warm_tools_on_component

for node in self.graph.nodes:
if hasattr(self.graph.nodes[node]["instance"], "warm_up"):
instance = self.graph.nodes[node]["instance"]
if hasattr(instance, "warm_up"):
logger.info("Warming up component {node}...", node=node)
self.graph.nodes[node]["instance"].warm_up()
instance.warm_up()
warm_tools_on_component(instance)

@staticmethod
def _create_component_span(
Expand Down
57 changes: 57 additions & 0 deletions haystack/core/pipeline/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import heapq
from collections.abc import Iterable
from copy import deepcopy
from functools import wraps
from itertools import count
Expand All @@ -14,6 +15,62 @@
logger = logging.getLogger(__name__)


def warm_tools_on_component(component: Any, field_name: Optional[str] = None) -> None:
"""
Warm any Tool or Toolset instances reachable from a component.
:param component: The component to search for tools
:param field_name: Optional specific field name to check instead of searching all attributes
"""
# Import locally to avoid circular dependencies
from haystack.tools.tool import Tool
from haystack.tools.toolset import Toolset

attributes = [field_name] if field_name else dir(component)

for attr_name in attributes:
try:
attr_value = getattr(component, attr_name)
except Exception as exc: # pragma: no cover - defensive
logger.debug(
"Failed to access attribute {attr_name} on component {component}: {exc}",
attr_name=attr_name,
component=component.__class__.__name__,
exc=exc,
)
continue

if attr_value is None or callable(attr_value):
continue

for candidate in _iter_tool_candidates(attr_value):
try:
if isinstance(candidate, (Tool, Toolset)):
logger.debug("Warming up tools for component {component}", component=component.__class__.__name__)
candidate.warm_up()
except Exception as exc: # pragma: no cover - defensive
logger.debug(
"Failed to warm tool candidate from attribute {attr_name} on component {component}: {exc}",
attr_name=attr_name,
component=component.__class__.__name__,
exc=exc,
)


def _iter_tool_candidates(value: Any) -> Iterable[Any]:
"""Yield potential Tool or Toolset instances from a value."""
from haystack.tools.tool import Tool
from haystack.tools.toolset import Toolset

if isinstance(value, (Tool, Toolset)):
return (value,)

if isinstance(value, Iterable) and not isinstance(value, (str, bytes, dict)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Iterables are needed for Toolset? Could you please explain?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah that's just trick to search for Tools. Essentially if it's an iterable (like a list) BUT NOT a string, bytes, or dict then return it so we can iterate through it. See above how we search for (Tool, Toolset) without encoding specifically ChatGenerator ToolInvoker or any other component (Agent in pipeline has tools) also some others tomorrow perhaps. Rather than hardcoding this we look through fields of a component to find tools to warmup

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps an overkill, can simplify as well.

return value

return ()


def _deepcopy_with_exceptions(obj: Any) -> Any:
"""
Attempts to perform a deep copy of the given object.
Expand Down
6 changes: 6 additions & 0 deletions haystack/tools/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ def tool_spec(self) -> dict[str, Any]:
"""
return {"name": self.name, "description": self.description, "parameters": self.parameters}

def warm_up(self) -> None:
"""
Warm up the Tool.
"""
pass

def invoke(self, **kwargs: Any) -> Any:
"""
Invoke the Tool with the provided keyword arguments.
Expand Down
73 changes: 72 additions & 1 deletion haystack/tools/toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,12 @@ def __contains__(self, item: Any) -> bool:
return item in self.tools
return False

def warm_up(self) -> None:
"""
Warm up the Toolset.
"""
pass

def add(self, tool: Union[Tool, "Toolset"]) -> None:
"""
Add a new Tool or merge another Toolset.
Expand Down Expand Up @@ -262,7 +268,7 @@ def __add__(self, other: Union[Tool, "Toolset", list[Tool]]) -> "Toolset":
if isinstance(other, Tool):
combined_tools = self.tools + [other]
elif isinstance(other, Toolset):
combined_tools = self.tools + list(other)
return _ToolsetWrapper([self, other])
elif isinstance(other, list) and all(isinstance(item, Tool) for item in other):
combined_tools = self.tools + other
else:
Expand All @@ -289,3 +295,68 @@ def __getitem__(self, index):
:returns: The Tool at the specified index
"""
return self.tools[index]


class _ToolsetWrapper(Toolset):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you explain why we need _ToolsetWrapper?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure we need it in this form (there should be something simpler), it was a hack to enable:

all_tools = perplexity_mcp + routing_mcp + some_other_tools

If they are all MCPToolset and not warmed up when addition above is executed (with current code) we'll just get garbage. We need to preserve their configs (URLs) and then after they've been warmed up they could be added (using our current code) - this way we effectively add them and preserve configs so that when we warm them up they all connect to the right servers.

"""
A wrapper that holds multiple toolsets and provides a unified interface.

This is used internally when combining different types of toolsets to preserve
their individual configurations while still being usable with ToolInvoker.
"""

def __init__(self, toolsets: list[Toolset]):
self.toolsets = toolsets
# Check for duplicate tool names across all toolsets
all_tools = []
for toolset in toolsets:
all_tools.extend(list(toolset))
_check_duplicate_tool_names(all_tools)
super().__init__(tools=all_tools)

def __iter__(self):
"""Iterate over all tools from all toolsets."""
for toolset in self.toolsets:
yield from toolset

def __contains__(self, item):
"""Check if a tool is in any of the toolsets."""
return any(item in toolset for toolset in self.toolsets)

def warm_up(self):
"""Warm up all toolsets."""
for toolset in self.toolsets:
toolset.warm_up()

def __len__(self):
"""Return total number of tools across all toolsets."""
return sum(len(toolset) for toolset in self.toolsets)

def __getitem__(self, index):
"""Get a tool by index across all toolsets."""
current_index = 0
for toolset in self.toolsets:
toolset_len = len(toolset)
if current_index + toolset_len > index:
return toolset[index - current_index]
current_index += toolset_len
raise IndexError("ToolsetWrapper index out of range")

def __add__(self, other):
"""Add another toolset or tool to this wrapper."""
# Import here to avoid circular reference issues
from haystack.tools.toolset import Toolset

if isinstance(other, Toolset):
# Add the toolset to our list
new_toolsets = self.toolsets + [other]
elif isinstance(other, Tool):
# Convert tool to a basic toolset and add it
new_toolsets = self.toolsets + [Toolset([other])]
elif isinstance(other, list) and all(isinstance(item, Tool) for item in other):
# Convert list of tools to a basic toolset and add it
new_toolsets = self.toolsets + [Toolset(other)]
else:
raise TypeError(f"Cannot add {type(other).__name__} to ToolsetWrapper")

return _ToolsetWrapper(new_toolsets)
Loading