class ARTAgentConfig(BaseSettings):
    """Configuration for the ART specialized agents' MCP tool access.

    Each field is a list of category names and/or individual tool names.
    Valid category names are the keys of ``agents.tool_filter.TOOL_GROUPS``.
    Set via a comma-separated environment variable, e.g.::

        ART_HYPOTHESIS_AGENT_TOOLS=core_tools,experiment,search_config,query_set

    Implementation note: the fields are typed ``list[str] | str`` so that
    pydantic-settings treats failed JSON parsing as a non-fatal parse failure
    and passes the raw string through to the ``_normalise_list`` validator,
    which converts it to ``list[str]``.  At runtime the value is always a
    ``list[str]``.
    """

    model_config = SettingsConfigDict(
        env_prefix="ART_",
        case_sensitive=False,
        extra="ignore",
    )

    hypothesis_agent_tools: list[str] | str = [
        "core_tools",
        "experiment",
        "search_config",
        "query_set",
    ]
    """Tool categories for the hypothesis agent.

    Defaults to search + experiment management + search configs + query sets.
    Judgment lists are excluded — the hypothesis agent only does pairwise
    sanity checks, not full offline evaluation.
    """

    evaluation_agent_tools: list[str] | str = [
        "core_tools",
        "search_relevance",
    ]
    """Tool categories for the evaluation agent.

    Defaults to search + all Search Relevance Workbench tools, giving the
    agent access to experiments, judgment lists, configs, and query sets.
    """

    ubi_agent_tools: list[str] | str = [
        "core_tools",
    ]
    """Tool categories for the user-behavior-analysis agent.

    Defaults to core search tools only — UBI analysis is read-only queries
    against the ubi_queries and ubi_events indices.
    """

    @field_validator(
        "hypothesis_agent_tools",
        "evaluation_agent_tools",
        "ubi_agent_tools",
        mode="before",
    )
    @classmethod
    def _normalise_list(cls, v: object) -> list[str]:
        """Normalise to ``list[str]``, trimming whitespace and dropping empty items.

        Handles both a raw comma-separated string (from an env var) and a
        plain list (from direct instantiation or a JSON-formatted env var).

        Args:
            v: The raw, not-yet-validated field value.

        Returns:
            The cleaned list of names.  Values that are neither ``str`` nor
            ``list`` are returned untouched so that the field's regular
            validation decides whether they are acceptable.
        """
        if isinstance(v, str):
            items: list[str] = v.split(",")
        elif isinstance(v, list):
            items = [str(i) for i in v]
        else:
            # Runs in "before" mode: hand anything unexpected back unchanged.
            return v  # type: ignore[return-value]
        return [item.strip() for item in items if item.strip()]
class FallbackAgentConfig(BaseSettings):
    """Settings that decide which MCP tools the fallback agent may use.

    The value comes from a comma-separated environment variable, e.g.::

        FALLBACK_AGENT_TOOLS=core_tools,search_relevance

    Leaving it empty (the default) keeps every MCP tool available.
    """

    model_config = SettingsConfigDict(
        extra="ignore",
        case_sensitive=False,
        env_prefix="FALLBACK_",
    )

    agent_tools: list[str] | str = []
    """Category names and/or tool names the fallback agent is limited to.

    The default is an empty list, i.e. no restriction is applied.
    """

    @field_validator("agent_tools", mode="before")
    @classmethod
    def _normalise_list(cls, v: object) -> list[str]:
        """Turn a comma-separated string or a list into a clean ``list[str]``.

        Entries are whitespace-trimmed and blank entries are discarded.
        Any other input type is handed back untouched.
        """
        if not isinstance(v, (str, list)):
            return v  # type: ignore[return-value]

        raw_entries = v.split(",") if isinstance(v, str) else map(str, v)
        trimmed = (entry.strip() for entry in raw_entries)
        return [entry for entry in trimmed if entry]
@@ -56,19 +66,23 @@ def create_fallback_agent( mcp_server_url = os.getenv("MCP_SERVER_URL", DEFAULT_MCP_SERVER_URL) mcp_client = MCPClient(lambda: streamablehttp_client(mcp_server_url, headers=headers)) + mcp_client.start() + + all_tools = list(mcp_client.list_tools_sync()) + tools = _select_tools(all_tools, _fallback_config.agent_tools) agent = Agent( system_prompt=FALLBACK_SYSTEM_PROMPT, - tools=[mcp_client], + tools=tools, ) - tool_count = len(agent.tool_registry.registry) log_info_event( logger, - f"Fallback agent initialized with {tool_count} MCP tools " + f"Fallback agent initialized with {len(tools)}/{len(all_tools)} MCP tools " f"(server={mcp_server_url}).", "fallback_agent.initialized", - tool_count=tool_count, + tool_count=len(tools), + total_tool_count=len(all_tools), mcp_server_url=mcp_server_url, opensearch_url=opensearch_url, ) diff --git a/src/agents/tool_filter.py b/src/agents/tool_filter.py new file mode 100644 index 0000000..7dc5ab2 --- /dev/null +++ b/src/agents/tool_filter.py @@ -0,0 +1,152 @@ +"""Shared MCP tool filtering utilities. + +Provides :data:`TOOL_GROUPS` (the category registry) and :func:`_select_tools` +(the runtime filter). Both the ART specialized agents and the fallback agent +import from here so tool-name definitions live in a single place. + +**Category names**: + +- ``core_tools`` — general search and index operations (SearchIndexTool, etc.) +- ``search_relevance`` — all Search Relevance Workbench tools (experiments, + judgment lists, search configurations, query sets) +- ``experiment`` — experiment lifecycle only +- ``judgment`` — judgment list management only +- ``search_config`` — search configuration management only +- ``query_set`` — query set management only + +Individual tool names (e.g. ``GetExperimentTool``) can also be used directly. 
+""" + +from __future__ import annotations + +from typing import Any + +from utils.logging_helpers import get_logger, log_info_event, log_warning_event + +logger = get_logger(__name__) + +# --------------------------------------------------------------------------- +# Tool group definitions +# Category names mirror the OpenSearch MCP Server's logical groupings. +# Each entry maps a category name to the exact tool names exposed by the server. +# --------------------------------------------------------------------------- +TOOL_GROUPS: dict[str, frozenset[str]] = { + # General search and index operations (OpenSearch MCP Server: "core_tools") + "core_tools": frozenset({ + "SearchIndexTool", + "ListIndexTool", + "CountTool", + "ExplainTool", + "DataDistributionTool", + "LogPatternAnalysisTool", + "GenericOpenSearchApiTool", + }), + # All Search Relevance Workbench tools (OpenSearch MCP Server: "search_relevance") + # Convenience meta-category — equivalent to experiment + judgment + + # search_config + query_set combined. 
+ "search_relevance": frozenset({ + # experiment lifecycle + "CreateExperimentTool", + "GetExperimentTool", + "DeleteExperimentTool", + # judgment list management + "CreateJudgmentListTool", + "CreateLLMJudgmentListTool", + "CreateUBIJudgmentListTool", + "GetJudgmentListTool", + "DeleteJudgmentListTool", + # search configuration (query DSL templates) + "CreateSearchConfigurationTool", + "GetSearchConfigurationTool", + "DeleteSearchConfigurationTool", + # query set management + "CreateQuerySetTool", + "GetQuerySetTool", + "DeleteQuerySetTool", + "SampleQuerySetTool", + }), + # Fine-grained sub-categories for selective access + "experiment": frozenset({ + "CreateExperimentTool", + "GetExperimentTool", + "DeleteExperimentTool", + }), + "judgment": frozenset({ + "CreateJudgmentListTool", + "CreateLLMJudgmentListTool", + "CreateUBIJudgmentListTool", + "GetJudgmentListTool", + "DeleteJudgmentListTool", + }), + "search_config": frozenset({ + "CreateSearchConfigurationTool", + "GetSearchConfigurationTool", + "DeleteSearchConfigurationTool", + }), + "query_set": frozenset({ + "CreateQuerySetTool", + "GetQuerySetTool", + "DeleteQuerySetTool", + "SampleQuerySetTool", + }), +} + + +def _select_tools( + tools: list[Any], + filters: list[str] | None = None, +) -> list[Any]: + """Filter MCP tools by category name and/or explicit tool name. + + Each item in *filters* is first looked up as a category key in + :data:`TOOL_GROUPS`. If no matching category is found the item is treated + as a direct tool name. This means category names and individual tool names + can be freely mixed in the same list — which is exactly what the + ``ART_*_AGENT_TOOLS`` and ``FALLBACK_AGENT_TOOLS`` environment variables + accept. + + If *filters* is ``None`` or empty the full list is returned unchanged. + + Args: + tools: Full list of Strands MCP tool objects. + filters: Category names (from :data:`TOOL_GROUPS`) and/or explicit + ``tool_name`` strings to include. 
+ + Returns: + Filtered list containing only the tools whose ``tool_name`` is in the + resolved allow-list. + """ + if not filters: + return tools + + allowed: set[str] = set() + for f in filters: + group = TOOL_GROUPS.get(f) + if group is not None: + allowed |= group + else: + # Not a known category — treat as a direct tool name. + allowed.add(f) + + selected = [t for t in tools if getattr(t, "tool_name", None) in allowed] + log_info_event( + logger, + f"[Agents] Tool filter applied: {len(selected)}/{len(tools)} tools selected " + f"(filters={filters})", + "agents.tools_filtered", + selected=len(selected), + total=len(tools), + ) + + actual_names = {getattr(t, "tool_name", None) for t in tools} - {None} + unmatched = allowed - actual_names + if unmatched: + log_warning_event( + logger, + f"[Agents] Configured tool(s) not available from MCP server: " + f"{sorted(unmatched)}. Check your tool configuration.", + "agents.tools_not_available", + unmatched=sorted(unmatched), + ) + + return selected diff --git a/tests/unit/test_art_config.py b/tests/unit/test_art_config.py new file mode 100644 index 0000000..9025383 --- /dev/null +++ b/tests/unit/test_art_config.py @@ -0,0 +1,93 @@ +""" +Unit tests for ARTAgentConfig. 
# Environment variables that feed ARTAgentConfig (``ART_`` prefix + field name).
_ART_TOOL_ENV_VARS = frozenset({
    "ART_HYPOTHESIS_AGENT_TOOLS",
    "ART_EVALUATION_AGENT_TOOLS",
    "ART_UBI_AGENT_TOOLS",
})


@pytest.fixture(autouse=True)
def _isolate_art_env(monkeypatch):
    """Remove ambient ``ART_*_AGENT_TOOLS`` variables before every test.

    ARTAgentConfig reads the process environment on instantiation, so a
    variable exported in the developer's shell or CI job would otherwise leak
    into the default-value assertions.  The config is case-insensitive, hence
    the case-insensitive match on the variable name.
    """
    import os

    for key in list(os.environ):
        if key.upper() in _ART_TOOL_ENV_VARS:
            monkeypatch.delenv(key)


class TestARTAgentConfigDefaults:
    """Default values are sane and use the correct category names."""

    def test_hypothesis_agent_defaults(self):
        config = ARTAgentConfig()
        assert config.hypothesis_agent_tools == [
            "core_tools",
            "experiment",
            "search_config",
            "query_set",
        ]

    def test_evaluation_agent_defaults(self):
        config = ARTAgentConfig()
        assert config.evaluation_agent_tools == ["core_tools", "search_relevance"]

    def test_ubi_agent_defaults(self):
        config = ARTAgentConfig()
        assert config.ubi_agent_tools == ["core_tools"]

    def test_judgment_excluded_from_hypothesis_defaults(self):
        """The hypothesis agent must not have judgment tools by default."""
        config = ARTAgentConfig()
        assert "judgment" not in config.hypothesis_agent_tools

    def test_search_relevance_covers_evaluation_defaults(self):
        """Evaluation agent uses the broad search_relevance category."""
        config = ARTAgentConfig()
        assert "search_relevance" in config.evaluation_agent_tools


class TestARTAgentConfigEnvVarParsing:
    """Comma-separated env vars are parsed into lists."""

    def test_comma_separated_string(self, monkeypatch):
        monkeypatch.setenv("ART_HYPOTHESIS_AGENT_TOOLS", "core_tools,experiment")
        config = ARTAgentConfig()
        assert config.hypothesis_agent_tools == ["core_tools", "experiment"]

    def test_single_item_string(self, monkeypatch):
        monkeypatch.setenv("ART_UBI_AGENT_TOOLS", "core_tools")
        config = ARTAgentConfig()
        assert config.ubi_agent_tools == ["core_tools"]

    def test_whitespace_is_trimmed(self, monkeypatch):
        monkeypatch.setenv(
            "ART_EVALUATION_AGENT_TOOLS", "core_tools , search_relevance"
        )
        config = ARTAgentConfig()
        assert config.evaluation_agent_tools == ["core_tools", "search_relevance"]

    def test_empty_string_produces_empty_list(self, monkeypatch):
        monkeypatch.setenv("ART_UBI_AGENT_TOOLS", "")
        config = ARTAgentConfig()
        assert config.ubi_agent_tools == []

    def test_individual_tool_name_accepted(self, monkeypatch):
        """A direct tool name (not a category) is accepted as-is."""
        monkeypatch.setenv("ART_UBI_AGENT_TOOLS", "core_tools,DataDistributionTool")
        config = ARTAgentConfig()
        assert config.ubi_agent_tools == ["core_tools", "DataDistributionTool"]

    def test_all_three_fields_independent(self, monkeypatch):
        monkeypatch.setenv("ART_HYPOTHESIS_AGENT_TOOLS", "core_tools")
        monkeypatch.setenv("ART_EVALUATION_AGENT_TOOLS", "search_relevance")
        monkeypatch.setenv("ART_UBI_AGENT_TOOLS", "core_tools,CountTool")
        config = ARTAgentConfig()
        assert config.hypothesis_agent_tools == ["core_tools"]
        assert config.evaluation_agent_tools == ["search_relevance"]
        assert config.ubi_agent_tools == ["core_tools", "CountTool"]

    def test_direct_instantiation_overrides_env(self, monkeypatch):
        """Directly passed values take precedence over env vars."""
        monkeypatch.setenv("ART_UBI_AGENT_TOOLS", "search_relevance")
        config = ARTAgentConfig(ubi_agent_tools=["core_tools"])
        assert config.ubi_agent_tools == ["core_tools"]
# Environment variables read by FallbackAgentConfig and ARTAgentConfig.
_AGENT_TOOL_ENV_VARS = frozenset({
    "FALLBACK_AGENT_TOOLS",
    "ART_HYPOTHESIS_AGENT_TOOLS",
    "ART_EVALUATION_AGENT_TOOLS",
    "ART_UBI_AGENT_TOOLS",
})


@pytest.fixture(autouse=True)
def _isolate_agent_tool_env(monkeypatch):
    """Remove ambient agent-tool variables before every test.

    Both config classes read the process environment on instantiation, so a
    variable exported outside the test run would otherwise leak into the
    default-driven assertions.  Matching is case-insensitive because the
    configs are declared with ``case_sensitive=False``.
    """
    import os

    for key in list(os.environ):
        if key.upper() in _AGENT_TOOL_ENV_VARS:
            monkeypatch.delenv(key)


class TestFallbackAgentConfigDefaults:
    """Default value is an empty list (all tools pass through)."""

    def test_default_is_empty_list(self):
        config = FallbackAgentConfig()
        assert config.agent_tools == []

    def test_empty_list_means_all_tools_pass_through(self):
        """Verify that _select_tools returns all tools when config is empty."""
        from unittest.mock import MagicMock

        from agents.tool_filter import _select_tools

        tools = [MagicMock(), MagicMock()]
        config = FallbackAgentConfig()
        assert _select_tools(tools, config.agent_tools) == tools


class TestFallbackAgentConfigEnvVarParsing:
    """Comma-separated env vars are parsed into lists."""

    def test_single_category(self, monkeypatch):
        monkeypatch.setenv("FALLBACK_AGENT_TOOLS", "core_tools")
        config = FallbackAgentConfig()
        assert config.agent_tools == ["core_tools"]

    def test_comma_separated_categories(self, monkeypatch):
        monkeypatch.setenv("FALLBACK_AGENT_TOOLS", "core_tools,search_relevance")
        config = FallbackAgentConfig()
        assert config.agent_tools == ["core_tools", "search_relevance"]

    def test_whitespace_is_trimmed(self, monkeypatch):
        monkeypatch.setenv("FALLBACK_AGENT_TOOLS", "core_tools , search_relevance")
        config = FallbackAgentConfig()
        assert config.agent_tools == ["core_tools", "search_relevance"]

    def test_empty_string_produces_empty_list(self, monkeypatch):
        monkeypatch.setenv("FALLBACK_AGENT_TOOLS", "")
        config = FallbackAgentConfig()
        assert config.agent_tools == []

    def test_individual_tool_name_accepted(self, monkeypatch):
        monkeypatch.setenv("FALLBACK_AGENT_TOOLS", "core_tools,GetExperimentTool")
        config = FallbackAgentConfig()
        assert config.agent_tools == ["core_tools", "GetExperimentTool"]

    def test_direct_instantiation_overrides_env(self, monkeypatch):
        monkeypatch.setenv("FALLBACK_AGENT_TOOLS", "search_relevance")
        config = FallbackAgentConfig(agent_tools=["core_tools"])
        assert config.agent_tools == ["core_tools"]


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_tool(name: str) -> MagicMock:
    """Create a mock MCP tool object with the given tool_name."""
    t = MagicMock()
    t.tool_name = name
    return t


def _tool_names(tools: list) -> set[str]:
    """Return the set of ``tool_name`` values carried by *tools*."""
    return {t.tool_name for t in tools}


# ---------------------------------------------------------------------------
# TOOL_GROUPS structure
# ---------------------------------------------------------------------------

class TestToolGroups:
    """TOOL_GROUPS has the expected categories and contents."""

    def test_expected_categories_present(self):
        assert set(TOOL_GROUPS) >= {
            "core_tools",
            "search_relevance",
            "experiment",
            "judgment",
            "search_config",
            "query_set",
        }

    def test_core_tools_contains_search_tools(self):
        assert {"SearchIndexTool", "ListIndexTool", "CountTool"} <= TOOL_GROUPS["core_tools"]

    def test_search_relevance_is_union_of_sub_categories(self):
        expected = (
            TOOL_GROUPS["experiment"]
            | TOOL_GROUPS["judgment"]
            | TOOL_GROUPS["search_config"]
            | TOOL_GROUPS["query_set"]
        )
        assert TOOL_GROUPS["search_relevance"] == expected

    def test_search_relevance_does_not_overlap_core_tools(self):
        assert TOOL_GROUPS["search_relevance"].isdisjoint(TOOL_GROUPS["core_tools"])

    def test_sub_categories_are_subsets_of_search_relevance(self):
        for cat in ("experiment", "judgment", "search_config", "query_set"):
            assert TOOL_GROUPS[cat] <= TOOL_GROUPS["search_relevance"], (
                f"'{cat}' tools should all appear in 'search_relevance'"
            )

    def test_all_entries_are_frozensets(self):
        for name, group in TOOL_GROUPS.items():
            assert isinstance(group, frozenset), f"TOOL_GROUPS['{name}'] should be a frozenset"


# ---------------------------------------------------------------------------
# _select_tools behaviour
# ---------------------------------------------------------------------------

class TestSelectTools:
    """_select_tools filters the tool list correctly."""

    # --- no-op cases ---

    def test_none_filter_returns_all(self):
        tools = [_make_tool("SearchIndexTool"), _make_tool("CreateExperimentTool")]
        assert _select_tools(tools, None) == tools

    def test_empty_filter_returns_all(self):
        tools = [_make_tool("SearchIndexTool"), _make_tool("CreateExperimentTool")]
        assert _select_tools(tools, []) == tools

    def test_empty_tools_list_returns_empty(self):
        assert _select_tools([], ["core_tools"]) == []

    # --- category filtering ---

    def test_single_category(self):
        core = [_make_tool(n) for n in TOOL_GROUPS["core_tools"]]
        srw = [_make_tool("CreateExperimentTool")]
        result = _select_tools(core + srw, ["core_tools"])
        assert _tool_names(result) == TOOL_GROUPS["core_tools"]

    def test_multiple_categories_returns_union(self):
        all_tools = [
            _make_tool(n)
            for n in TOOL_GROUPS["core_tools"] | TOOL_GROUPS["experiment"] | TOOL_GROUPS["judgment"]
        ]
        result = _select_tools(all_tools, ["experiment", "judgment"])
        expected = TOOL_GROUPS["experiment"] | TOOL_GROUPS["judgment"]
        assert _tool_names(result) == expected

    def test_search_relevance_meta_category(self):
        srw_names = TOOL_GROUPS["search_relevance"]
        all_tools = [_make_tool(n) for n in srw_names | TOOL_GROUPS["core_tools"]]
        result = _select_tools(all_tools, ["search_relevance"])
        assert _tool_names(result) == srw_names

    # --- direct tool name as filter ---

    def test_direct_tool_name_not_in_any_category(self):
        """An item that is not a category key is treated as a direct tool name."""
        tools = [_make_tool("SearchIndexTool"), _make_tool("MyCustomTool")]
        result = _select_tools(tools, ["MyCustomTool"])
        assert _tool_names(result) == {"MyCustomTool"}

    def test_direct_tool_name_in_existing_category(self):
        """A single tool name that happens to be inside a category is also matched."""
        tools = [_make_tool("SearchIndexTool"), _make_tool("CountTool")]
        result = _select_tools(tools, ["SearchIndexTool"])
        assert _tool_names(result) == {"SearchIndexTool"}

    # --- mixed category + direct tool name (the config use-case) ---

    def test_mixed_category_and_tool_name(self):
        """Category names and individual tool names can be freely mixed."""
        tools = [
            _make_tool("SearchIndexTool"),       # in core_tools
            _make_tool("CountTool"),             # in core_tools
            _make_tool("DataDistributionTool"),  # in core_tools
            _make_tool("CreateExperimentTool"),  # in experiment / search_relevance
        ]
        # Give only CountTool from core_tools plus CreateExperimentTool by name
        result = _select_tools(tools, ["CountTool", "CreateExperimentTool"])
        assert _tool_names(result) == {"CountTool", "CreateExperimentTool"}

    def test_category_and_extra_tool_name_combined(self):
        """Category expands normally; extra tool name adds on top."""
        core_tools = [_make_tool(n) for n in TOOL_GROUPS["core_tools"]]
        extra = _make_tool("CreateExperimentTool")
        result = _select_tools(core_tools + [extra], ["core_tools", "CreateExperimentTool"])
        assert _tool_names(result) == TOOL_GROUPS["core_tools"] | {"CreateExperimentTool"}

    # --- tools without tool_name ---

    def test_tool_without_tool_name_is_excluded(self):
        """Tools that lack a tool_name attribute are silently excluded."""
        good = _make_tool("SearchIndexTool")
        bad = MagicMock(spec=[])  # no tool_name attribute
        result = _select_tools([good, bad], ["core_tools"])
        assert good in result
        assert bad not in result

    def test_tool_with_none_tool_name_is_excluded(self):
        t = MagicMock()
        t.tool_name = None
        result = _select_tools([t], ["core_tools"])
        assert result == []

    # --- warning for unavailable tools ---

    def test_warns_when_configured_tool_not_in_mcp_server(self):
        """A warning is emitted when a configured tool name has no match in the tool list."""
        tools = [_make_tool("SearchIndexTool")]
        with patch("agents.tool_filter.log_warning_event") as mock_warn:
            _select_tools(tools, ["SearchIndexTool", "NonExistentTool"])
        mock_warn.assert_called_once()
        args = mock_warn.call_args
        assert "NonExistentTool" in str(args)

    def test_no_warning_when_all_configured_tools_available(self):
        """No warning is emitted when every configured tool is present in the tool list."""
        tools = [_make_tool(n) for n in TOOL_GROUPS["core_tools"]]
        with patch("agents.tool_filter.log_warning_event") as mock_warn:
            _select_tools(tools, ["core_tools"])
        mock_warn.assert_not_called()

    def test_warns_for_category_tools_not_in_mcp_server(self):
        """Warning fires when a category expands to names absent from the tool list."""
        # Only provide one of the core tools
        tools = [_make_tool("SearchIndexTool")]
        with patch("agents.tool_filter.log_warning_event") as mock_warn:
            _select_tools(tools, ["core_tools"])
        mock_warn.assert_called_once()
        # The warning should mention the missing names.  ``unmatched`` is passed
        # as a keyword argument, so read it straight from the recorded kwargs.
        unmatched_arg = mock_warn.call_args.kwargs["unmatched"]
        missing = TOOL_GROUPS["core_tools"] - {"SearchIndexTool"}
        assert set(unmatched_arg) == missing

    # --- ordering is preserved ---

    def test_order_is_preserved(self):
        names = ["SearchIndexTool", "CountTool", "ListIndexTool"]
        tools = [_make_tool(n) for n in names]
        result = _select_tools(tools, ["core_tools"])
        assert [t.tool_name for t in result] == names


# ---------------------------------------------------------------------------
# Config-driven filtering (integration between ARTAgentConfig and _select_tools)
# ---------------------------------------------------------------------------

class TestConfigDrivenFiltering:
    """_select_tools works correctly when driven by ARTAgentConfig values."""

    def test_default_hypothesis_tools_excludes_judgment(self):
        from agents.art.config import ARTAgentConfig
        config = ARTAgentConfig()
        all_tools = [_make_tool(n) for n in TOOL_GROUPS["search_relevance"] | TOOL_GROUPS["core_tools"]]
        result = _select_tools(all_tools, config.hypothesis_agent_tools)
        result_names = _tool_names(result)
        # No judgment tools should appear
        assert result_names.isdisjoint(TOOL_GROUPS["judgment"])

    def test_default_evaluation_tools_includes_all_srw(self):
        from agents.art.config import ARTAgentConfig
        config = ARTAgentConfig()
        all_tools = [_make_tool(n) for n in TOOL_GROUPS["search_relevance"] | TOOL_GROUPS["core_tools"]]
        result = _select_tools(all_tools, config.evaluation_agent_tools)
        result_names = _tool_names(result)
        assert TOOL_GROUPS["search_relevance"] <= result_names

    def test_default_ubi_tools_excludes_srw(self):
        from agents.art.config import ARTAgentConfig
        config = ARTAgentConfig()
        all_tools = [_make_tool(n) for n in TOOL_GROUPS["search_relevance"] | TOOL_GROUPS["core_tools"]]
        result = _select_tools(all_tools, config.ubi_agent_tools)
        result_names = _tool_names(result)
        assert result_names.isdisjoint(TOOL_GROUPS["search_relevance"])

    def test_env_var_override_adds_individual_tool(self, monkeypatch):
        """ART_UBI_AGENT_TOOLS=core_tools,DataDistributionTool works end-to-end."""
        from agents.art.config import ARTAgentConfig
        monkeypatch.setenv("ART_UBI_AGENT_TOOLS", "core_tools,DataDistributionTool")
        config = ARTAgentConfig()
        # DataDistributionTool is already in core_tools, so this is a no-op in practice,
        # but the filter should still match it without error.
        tools = [_make_tool("DataDistributionTool"), _make_tool("CreateExperimentTool")]
        result = _select_tools(tools, config.ubi_agent_tools)
        assert _tool_names(result) == {"DataDistributionTool"}

    def test_env_var_override_restricts_evaluation_agent(self, monkeypatch):
        from agents.art.config import ARTAgentConfig
        monkeypatch.setenv("ART_EVALUATION_AGENT_TOOLS", "experiment")
        config = ARTAgentConfig()
        all_tools = [_make_tool(n) for n in TOOL_GROUPS["search_relevance"] | TOOL_GROUPS["core_tools"]]
        result = _select_tools(all_tools, config.evaluation_agent_tools)
        assert _tool_names(result) == TOOL_GROUPS["experiment"]