Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backend/app/agent/factory/browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ def browser_agent(options: Chat):
skill_toolkit = message_integration.register_toolkits(skill_toolkit)

search_tools = SearchToolkit.get_can_use_tools(
options.project_id, agent_name=Agents.browser_agent
options.project_id,
agent_name=Agents.browser_agent,
)
if search_tools:
search_tools = message_integration.register_functions(search_tools)
Expand Down
3 changes: 2 additions & 1 deletion backend/app/agent/factory/developer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ async def developer_agent(options: Chat):
skill_toolkit = message_integration.register_toolkits(skill_toolkit)

search_tools = SearchToolkit.get_can_use_tools(
options.project_id, agent_name=Agents.developer_agent
options.project_id,
agent_name=Agents.developer_agent,
)
if search_tools:
search_tools = message_integration.register_functions(search_tools)
Expand Down
3 changes: 2 additions & 1 deletion backend/app/agent/factory/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ async def document_agent(options: Chat):
skill_toolkit = message_integration.register_toolkits(skill_toolkit)

search_tools = SearchToolkit.get_can_use_tools(
options.project_id, agent_name=Agents.document_agent
options.project_id,
agent_name=Agents.document_agent,
)
if search_tools:
search_tools = message_integration.register_functions(search_tools)
Expand Down
3 changes: 2 additions & 1 deletion backend/app/agent/factory/multi_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ def multi_modal_agent(options: Chat):
skill_toolkit = message_integration.register_toolkits(skill_toolkit)

search_tools = SearchToolkit.get_can_use_tools(
options.project_id, agent_name=Agents.multi_modal_agent
options.project_id,
agent_name=Agents.multi_modal_agent,
)
if search_tools:
search_tools = message_integration.register_functions(search_tools)
Expand Down
3 changes: 2 additions & 1 deletion backend/app/agent/factory/social_media.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ async def social_media_agent(options: Chat):
user_id=options.skill_config_user_id(),
).get_tools(),
*SearchToolkit.get_can_use_tools(
options.project_id, agent_name=Agents.social_media_agent
options.project_id,
agent_name=Agents.social_media_agent,
),
# *DiscordToolkit(options.project_id).get_tools(),
# *GoogleSuiteToolkit(options.project_id).get_tools(),
Expand Down
47 changes: 32 additions & 15 deletions backend/app/agent/toolkit/search_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,16 +165,21 @@ def cloud_search_google(
)
return res.json()

# @listen_toolkit(
# BaseSearchToolkit.search_duckduckgo,
# lambda _,
# query,
# source="text",
# max_results=5: f"Search DuckDuckGo with query '{query}', source '{source}', and max results {max_results}",
# lambda result: f"Search DuckDuckGo returned {len(result)} results",
# )
# def search_duckduckgo(self, query: str, source: str = "text", max_results: int = 5) -> list[dict[str, Any]]:
# return super().search_duckduckgo(query, source, max_results)
@listen_toolkit(
BaseSearchToolkit.search_duckduckgo,
lambda _,
query,
source="text",
number_of_result_pages=10: f"Search DuckDuckGo with query '{query}', source '{source}', and {number_of_result_pages} result pages",
lambda result: f"Search DuckDuckGo returned {len(result)} results",
)
def search_duckduckgo(
self,
query: str,
source: str = "text",
number_of_result_pages: int = 10,
) -> list[dict[str, Any]]:
return super().search_duckduckgo(query, source, number_of_result_pages)

# @listen_toolkit(
# BaseSearchToolkit.tavily_search,
Expand Down Expand Up @@ -365,9 +370,14 @@ def cloud_search_google(

@classmethod
def get_can_use_tools(
cls, api_task_id: str, agent_name: str | None = None
cls,
api_task_id: str,
agent_name: str | None = None,
) -> list[FunctionTool]:
search_toolkit = SearchToolkit(api_task_id, agent_name=agent_name)
search_toolkit = SearchToolkit(
api_task_id,
agent_name=agent_name,
)
tools = [
# FunctionTool(search_toolkit.search_wiki),
# FunctionTool(search_toolkit.search_duckduckgo),
Expand All @@ -380,10 +390,17 @@ def get_can_use_tools(
# if env("BRAVE_API_KEY"):
# tools.append(FunctionTool(search_toolkit.search_brave))

if (env("GOOGLE_API_KEY") and env("SEARCH_ENGINE_ID")) or env(
"cloud_api_key"
):
if env("GOOGLE_API_KEY") and env("SEARCH_ENGINE_ID"):
logger.info("Using search tool: search_google (user API keys)")
tools.append(FunctionTool(search_toolkit.search_google))
elif env("cloud_api_key"):
logger.info("Using search tool: search_google (cloud proxy)")
tools.append(FunctionTool(search_toolkit.search_google))
else:
logger.info(
"Using search tool: search_duckduckgo (no API keys configured)"
)
tools.append(FunctionTool(search_toolkit.search_duckduckgo))

# if env("TAVILY_API_KEY"):
# tools.append(FunctionTool(search_toolkit.tavily_search))
Expand Down
1 change: 1 addition & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies = [
"opentelemetry-api>=1.34.1",
"opentelemetry-sdk>=1.34.1",
"opentelemetry-exporter-otlp-proto-http>=1.34.1",
"duckduckgo-search>=7.0.0",
]


Expand Down
175 changes: 175 additions & 0 deletions backend/tests/app/agent/toolkit/test_search_toolkit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========

import asyncio
from unittest.mock import MagicMock, patch

import pytest

from app.agent.toolkit.search_toolkit import SearchToolkit
from app.service.task import TaskLock, task_locks

pytestmark = pytest.mark.unit

_ENV_MOD = "app.agent.toolkit.search_toolkit.env"
_ENV_NOT_EMPTY_MOD = "app.agent.toolkit.search_toolkit.env_not_empty"
_TEST_TASK_ID = "test_task_search"


def _ensure_task_lock(task_id: str = _TEST_TASK_ID):
"""Ensure a task lock exists for the given task_id."""
if task_id not in task_locks:
task_locks[task_id] = TaskLock(
id=task_id, queue=asyncio.Queue(), human_input={}
)


def test_get_can_use_tools_duckduckgo_fallback_when_no_keys():
"""When no Google API keys or cloud_api_key, DuckDuckGo is used."""
with patch(_ENV_MOD, return_value=None):
tools = SearchToolkit.get_can_use_tools("test_task")
assert len(tools) == 1
assert "duckduckgo" in tools[0].func.__name__


def test_get_can_use_tools_google_api_when_keys_present():
"""When Google API keys are present, search_google is used."""

def mock_env(key, default=None):
return {
"GOOGLE_API_KEY": "test-key",
"SEARCH_ENGINE_ID": "test-cx",
}.get(key, default)

with patch(_ENV_MOD, side_effect=mock_env):
tools = SearchToolkit.get_can_use_tools("test_task")
assert len(tools) == 1
assert "search_google" == tools[0].func.__name__


def test_get_can_use_tools_cloud_api_key():
"""When cloud_api_key is present, search_google is used."""

def mock_env(key, default=None):
return {"cloud_api_key": "cloud-key"}.get(key, default)

with patch(_ENV_MOD, side_effect=mock_env):
tools = SearchToolkit.get_can_use_tools("test_task")
assert len(tools) == 1
assert "search_google" == tools[0].func.__name__


def test_get_can_use_tools_accepts_agent_name():
"""get_can_use_tools passes agent_name to the toolkit instance."""
with patch(_ENV_MOD, return_value=None):
tools = SearchToolkit.get_can_use_tools(
"test_task", agent_name="test_agent"
)
assert len(tools) == 1


def test_search_google_uses_user_keys():
"""search_google uses user-configured API keys when available."""
_ensure_task_lock()

def mock_env(key, default=None):
return {
"GOOGLE_API_KEY": "user-key",
"SEARCH_ENGINE_ID": "user-cx",
}.get(key, default)

toolkit = SearchToolkit(_TEST_TASK_ID)
with patch(_ENV_MOD, side_effect=mock_env):
with patch.object(
SearchToolkit.__bases__[0],
"search_google",
return_value=[{"result_id": 1, "title": "test"}],
) as mock_super:
result = toolkit.search_google("test query")
mock_super.assert_called_once()
assert result == [{"result_id": 1, "title": "test"}]


def test_search_google_falls_back_to_cloud():
"""search_google falls back to cloud search when no user keys."""
_ensure_task_lock()

toolkit = SearchToolkit(_TEST_TASK_ID)
with patch(_ENV_MOD, return_value=None):
with patch.object(
toolkit,
"cloud_search_google",
return_value=[{"result_id": 1, "title": "cloud"}],
) as mock_cloud:
result = toolkit.search_google("test query")
mock_cloud.assert_called_once_with("test query", "web", 10, 1)
assert result == [{"result_id": 1, "title": "cloud"}]


def test_get_can_use_tools_google_keys_no_duckduckgo():
"""When Google API keys are present, DuckDuckGo is NOT included."""

def mock_env(key, default=None):
return {
"GOOGLE_API_KEY": "test-key",
"SEARCH_ENGINE_ID": "test-cx",
}.get(key, default)

with patch(_ENV_MOD, side_effect=mock_env):
tools = SearchToolkit.get_can_use_tools("test_task")
names = [t.func.__name__ for t in tools]
assert "duckduckgo" not in " ".join(names)


def test_search_duckduckgo_delegates_to_base():
"""search_duckduckgo delegates to the base class method."""
_ensure_task_lock()

toolkit = SearchToolkit(_TEST_TASK_ID)
expected = [{"result_id": 1, "title": "duck result"}]

with patch.object(
SearchToolkit.__bases__[0],
"search_duckduckgo",
return_value=expected,
) as mock_super:
result = toolkit.search_duckduckgo("test query")
mock_super.assert_called_once()
assert result == expected


def test_cloud_search_google_calls_server():
"""cloud_search_google makes HTTP request to server proxy."""
toolkit = SearchToolkit("test_task")

mock_response = MagicMock()
mock_response.json.return_value = [{"result_id": 1, "title": "proxied"}]

with (
patch(
_ENV_NOT_EMPTY_MOD,
side_effect=lambda k: {
"SERVER_URL": "http://test-server",
"cloud_api_key": "test-cloud-key",
}[k],
),
patch(
"app.agent.toolkit.search_toolkit.httpx.get",
return_value=mock_response,
) as mock_get,
):
result = toolkit.cloud_search_google("test query")

mock_get.assert_called_once()
assert result == [{"result_id": 1, "title": "proxied"}]
Loading
Loading