diff --git a/aieng-eval-agents/aieng/agent_evals/async_client_manager.py b/aieng-eval-agents/aieng/agent_evals/async_client_manager.py index a1f36f42..451912f0 100644 --- a/aieng-eval-agents/aieng/agent_evals/async_client_manager.py +++ b/aieng-eval-agents/aieng/agent_evals/async_client_manager.py @@ -7,7 +7,6 @@ import logging from aieng.agent_evals.configs import Configs -from aieng.agent_evals.tools import ReadOnlySqlDatabase from langfuse import Langfuse from openai import AsyncOpenAI @@ -63,8 +62,6 @@ def __init__(self, configs: Configs | None = None) -> None: """ self._configs: Configs | None = configs self._openai_client: AsyncOpenAI | None = None - self._aml_db: ReadOnlySqlDatabase | None = None - self._report_generation_db: ReadOnlySqlDatabase | None = None self._langfuse_client: Langfuse | None = None self._otel_instrumented: bool = False self._initialized: bool = False @@ -98,46 +95,6 @@ def openai_client(self) -> AsyncOpenAI: self._initialized = True return self._openai_client - def report_generation_db(self, agent_name: str = "ReportGenerationAgent") -> ReadOnlySqlDatabase: - """Get or create Report Generation database connection. - - Returns - ------- - ReadOnlySqlDatabase - The Report Generation database connection instance. - """ - if self._report_generation_db is None: - if self.configs.report_generation_db is None: - raise ValueError("Report Generation database configuration is missing.") - - self._report_generation_db = ReadOnlySqlDatabase( - connection_uri=self.configs.report_generation_db.build_uri(), - agent_name=agent_name, - ) - self._initialized = True - - return self._report_generation_db - - def aml_db(self, agent_name: str = "FraudInvestigationAnalyst") -> ReadOnlySqlDatabase: - """Get or create AML database connection. - - Returns - ------- - ReadOnlySqlDatabase - The Report Generation database connection instance. - """ - if self._aml_db is None: - if self.configs.aml_db is None: - raise ValueError("AML database configuration is missing.") - - self._aml_db = ReadOnlySqlDatabase( - connection_uri=self.configs.aml_db.build_uri(), - agent_name=agent_name, - ) - self._initialized = True - - return self._aml_db - @property def langfuse_client(self) -> Langfuse: """Get or create Langfuse client. @@ -183,21 +140,13 @@ def otel_instrumented(self, value: bool) -> None: async def close(self) -> None: """Close all initialized async clients. - This method closes the OpenAI client, database connections, and Langfuse - client if they have been initialized. + This method closes the OpenAI client and Langfuse client + if they have been initialized. """ if self._openai_client is not None: await self._openai_client.close() self._openai_client = None - if self._aml_db is not None: - self._aml_db.close() - self._aml_db = None - - if self._report_generation_db is not None: - self._report_generation_db.close() - self._report_generation_db = None - if self._langfuse_client is not None: self._langfuse_client.flush() self._langfuse_client = None diff --git a/aieng-eval-agents/aieng/agent_evals/db_manager.py b/aieng-eval-agents/aieng/agent_evals/db_manager.py new file mode 100644 index 00000000..af3e06b6 --- /dev/null +++ b/aieng-eval-agents/aieng/agent_evals/db_manager.py @@ -0,0 +1,137 @@ +"""Database connection manager for Gradio applications. + +Provides centralized DB lifecycle management independent of async client handling, +avoiding circular imports with the tools package. +""" + +import logging + +from aieng.agent_evals.configs import Configs +from aieng.agent_evals.tools.sql_database import ReadOnlySqlDatabase + + +logger = logging.getLogger(__name__) + + +class DbManager: + """Manages database connections with lazy initialization. + + Parameters + ---------- + configs : Configs | None, optional + Configuration object. If ``None``, created lazily on first access. + """ + + _singleton_instance: "DbManager | None" = None + + @classmethod + def get_instance(cls) -> "DbManager": + """Get the singleton instance of the DB manager. + + Returns + ------- + DbManager + The singleton instance of the DB manager. + """ + if cls._singleton_instance is None: + cls._singleton_instance = DbManager() + return cls._singleton_instance + + def __init__(self, configs: Configs | None = None) -> None: + self._configs: Configs | None = configs + self._aml_db: ReadOnlySqlDatabase | None = None + self._report_generation_db: ReadOnlySqlDatabase | None = None + + @property + def configs(self) -> Configs: + """Get or create configs instance. + + Returns + ------- + Configs + The configuration instance. + """ + if self._configs is None: + self._configs = Configs() # type: ignore[call-arg] + return self._configs + + @configs.setter + def configs(self, value: Configs) -> None: + """Set the configs instance. + + Parameters + ---------- + value : Configs + The configuration instance to set. + """ + self._configs = value + + def aml_db(self, agent_name: str = "FraudInvestigationAnalyst") -> ReadOnlySqlDatabase: + """Get or create the AML database connection. + + Parameters + ---------- + agent_name : str, optional + Name of the agent using this connection, + by default ``"FraudInvestigationAnalyst"``. + + Returns + ------- + ReadOnlySqlDatabase + The AML database connection instance. + + Raises + ------ + ValueError + If AML database configuration is missing. + """ + if self._aml_db is None: + if self.configs.aml_db is None: + raise ValueError("AML database configuration is missing.") + + self._aml_db = ReadOnlySqlDatabase( + connection_uri=self.configs.aml_db.build_uri(), + agent_name=agent_name, + ) + + return self._aml_db + + def report_generation_db(self, agent_name: str = "ReportGenerationAgent") -> ReadOnlySqlDatabase: + """Get or create the Report Generation database connection. + + Parameters + ---------- + agent_name : str, optional + Name of the agent using this connection, + by default ``"ReportGenerationAgent"``. + + Returns + ------- + ReadOnlySqlDatabase + The Report Generation database connection instance. + + Raises + ------ + ValueError + If Report Generation database configuration is missing. + """ + if self._report_generation_db is None: + if self.configs.report_generation_db is None: + raise ValueError("Report Generation database configuration is missing.") + + self._report_generation_db = ReadOnlySqlDatabase( + connection_uri=self.configs.report_generation_db.build_uri(), + agent_name=agent_name, + ) + + return self._report_generation_db + + def close(self) -> None: + """Dispose of all database connections.""" + if self._aml_db is not None: + self._aml_db.close() + self._aml_db = None + + if self._report_generation_db is not None: + self._report_generation_db.close() + self._report_generation_db = None diff --git a/aieng-eval-agents/aieng/agent_evals/report_generation/agent.py b/aieng-eval-agents/aieng/agent_evals/report_generation/agent.py index 4daf8a1d..5ac660dc 100644 --- a/aieng-eval-agents/aieng/agent_evals/report_generation/agent.py +++ b/aieng-eval-agents/aieng/agent_evals/report_generation/agent.py @@ -21,6 +21,7 @@ from typing import Any from aieng.agent_evals.async_client_manager import AsyncClientManager +from aieng.agent_evals.db_manager import DbManager from aieng.agent_evals.langfuse import setup_langfuse_tracer from aieng.agent_evals.report_generation.file_writer import ReportFileWriter from google.adk.agents import Agent @@ -60,6 +61,7 @@ def get_report_generation_agent( # Get the client manager singleton instance client_manager = AsyncClientManager.get_instance() + db_manager = DbManager.get_instance() report_file_writer = ReportFileWriter(reports_output_path) # Define an agent using Google ADK @@ -68,8 +70,8 @@ def get_report_generation_agent( model=client_manager.configs.default_worker_model, instruction=instructions, tools=[ - client_manager.report_generation_db().execute, - client_manager.report_generation_db().get_schema_info, + db_manager.report_generation_db().execute, + db_manager.report_generation_db().get_schema_info, report_file_writer.write_xlsx, ], ) diff --git a/aieng-eval-agents/aieng/agent_evals/report_generation/evaluation.py b/aieng-eval-agents/aieng/agent_evals/report_generation/evaluation.py index 9099170a..20fe60de 100644 --- a/aieng-eval-agents/aieng/agent_evals/report_generation/evaluation.py +++ b/aieng-eval-agents/aieng/agent_evals/report_generation/evaluation.py @@ -16,6 +16,7 @@ from typing import Any from aieng.agent_evals.async_client_manager import AsyncClientManager +from aieng.agent_evals.db_manager import DbManager from aieng.agent_evals.report_generation.agent import EventParser, EventType, get_report_generation_agent from aieng.agent_evals.report_generation.prompts import ( MAIN_AGENT_INSTRUCTIONS, @@ -106,6 +107,7 @@ async def evaluate( try: # Gracefully close the services + DbManager.get_instance().close() await client_manager.close() except Exception as e: logger.warning(f"Client manager services not closed successfully: {e}") diff --git a/aieng-eval-agents/tests/aieng/agent_evals/test_async_client_manager.py b/aieng-eval-agents/tests/aieng/agent_evals/test_async_client_manager.py new file mode 100644 index 00000000..93424e5a --- /dev/null +++ b/aieng-eval-agents/tests/aieng/agent_evals/test_async_client_manager.py @@ -0,0 +1,84 @@ +"""Tests for AsyncClientManager singleton and client lifecycle.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from aieng.agent_evals.async_client_manager import AsyncClientManager + + +@pytest.fixture(autouse=True) +def _reset_singleton(): + """Reset AsyncClientManager singleton before and after each test.""" + AsyncClientManager._singleton_instance = None + yield + AsyncClientManager._singleton_instance = None + + +class TestGetInstance: + """Tests for the get_instance() class method.""" + + def test_returns_same_instance(self): + """get_instance() always returns the same object.""" + first = AsyncClientManager.get_instance() + second = AsyncClientManager.get_instance() + assert first is second + + def test_constructor_creates_separate_instance(self): + """Direct constructor creates a different object than get_instance().""" + singleton = AsyncClientManager.get_instance() + separate = AsyncClientManager() + assert singleton is not separate + + +class TestConfigs: + """Tests for lazy config creation.""" + + def test_lazy_config_creation(self): + """Accessing .configs creates a Configs instance when none was provided.""" + manager = AsyncClientManager() + assert manager._configs is None + with patch("aieng.agent_evals.async_client_manager.Configs") as mock_configs_cls: + mock_instance = MagicMock() + mock_configs_cls.return_value = mock_instance + result = manager.configs + assert result is mock_instance + + +class TestClose: + """Tests for close() method.""" + + @pytest.mark.asyncio + async def test_closes_openai_client(self): + """close() closes the OpenAI client.""" + manager = AsyncClientManager() + mock_client = AsyncMock() + manager._openai_client = mock_client + manager._initialized = True + + await manager.close() + + mock_client.close.assert_awaited_once() + assert manager._openai_client is None + + @pytest.mark.asyncio + async def test_flushes_langfuse(self): + """close() flushes and clears the Langfuse client.""" + manager = AsyncClientManager() + mock_langfuse = MagicMock() + manager._langfuse_client = mock_langfuse + manager._initialized = True + + await manager.close() + + mock_langfuse.flush.assert_called_once() + assert manager._langfuse_client is None + + @pytest.mark.asyncio + async def test_resets_initialized(self): + """close() sets _initialized to False.""" + manager = AsyncClientManager() + manager._initialized = True + + await manager.close() + + assert manager._initialized is False diff --git a/aieng-eval-agents/tests/aieng/agent_evals/test_db_manager.py b/aieng-eval-agents/tests/aieng/agent_evals/test_db_manager.py new file mode 100644 index 00000000..65387219 --- /dev/null +++ b/aieng-eval-agents/tests/aieng/agent_evals/test_db_manager.py @@ -0,0 +1,161 @@ +"""Tests for DbManager singleton and database connection management.""" + +from unittest.mock import MagicMock, patch + +import pytest +from aieng.agent_evals.db_manager import DbManager + + +@pytest.fixture(autouse=True) +def _reset_singleton(): + """Reset DbManager singleton before and after each test.""" + DbManager._singleton_instance = None + yield + DbManager._singleton_instance = None + + +class TestGetInstance: + """Tests for the get_instance() class method.""" + + def test_returns_same_instance(self): + """get_instance() always returns the same object.""" + assert DbManager.get_instance() is DbManager.get_instance() + + def test_constructor_creates_separate_instance(self): + """Direct constructor creates a different object than get_instance().""" + singleton = DbManager.get_instance() + separate = DbManager() + assert singleton is not separate + + +class TestConfigHandling: + """Tests for lazy config creation and setter.""" + + def test_lazy_config_creation(self): + """Accessing .configs creates a Configs instance when none was provided.""" + manager = DbManager() + assert manager._configs is None + with patch("aieng.agent_evals.db_manager.Configs") as mock_configs_cls: + mock_instance = MagicMock() + mock_configs_cls.return_value = mock_instance + result = manager.configs + assert result is mock_instance + + def test_configs_setter(self): + """Setting .configs stores the value.""" + manager = DbManager() + mock_configs = MagicMock() + manager.configs = mock_configs + assert manager.configs is mock_configs + + +class TestAmlDb: + """Tests for aml_db() method.""" + + def test_raises_when_config_missing(self): + """aml_db() raises ValueError when aml_db config is None.""" + mock_configs = MagicMock() + mock_configs.aml_db = None + manager = DbManager(configs=mock_configs) + with pytest.raises(ValueError, match="AML database configuration is missing"): + manager.aml_db() + + @patch("aieng.agent_evals.db_manager.ReadOnlySqlDatabase") + def test_creates_correct_connection(self, mock_db_cls): + """aml_db() creates a ReadOnlySqlDatabase with the right URI.""" + mock_configs = MagicMock() + mock_configs.aml_db.build_uri.return_value = "sqlite:///test.db" + manager = DbManager(configs=mock_configs) + + result = manager.aml_db() + + mock_db_cls.assert_called_once_with( + connection_uri="sqlite:///test.db", + agent_name="FraudInvestigationAnalyst", + ) + assert result is mock_db_cls.return_value + + @patch("aieng.agent_evals.db_manager.ReadOnlySqlDatabase") + def test_returns_cached_instance(self, mock_db_cls): + """Repeated calls return the same instance without re-creating.""" + mock_configs = MagicMock() + mock_configs.aml_db.build_uri.return_value = "sqlite:///test.db" + manager = DbManager(configs=mock_configs) + + first = manager.aml_db() + second = manager.aml_db() + + assert first is second + assert mock_db_cls.call_count == 1 + + +class TestReportGenerationDb: + """Tests for report_generation_db() method.""" + + def test_raises_when_config_missing(self): + """report_generation_db() raises ValueError when config is None.""" + mock_configs = MagicMock() + mock_configs.report_generation_db = None + manager = DbManager(configs=mock_configs) + with pytest.raises(ValueError, match="Report Generation database configuration is missing"): + manager.report_generation_db() + + @patch("aieng.agent_evals.db_manager.ReadOnlySqlDatabase") + def test_creates_correct_connection(self, mock_db_cls): + """report_generation_db() creates a ReadOnlySqlDatabase with the right URI.""" + mock_configs = MagicMock() + mock_configs.report_generation_db.build_uri.return_value = "sqlite:///reports.db" + manager = DbManager(configs=mock_configs) + + result = manager.report_generation_db() + + mock_db_cls.assert_called_once_with( + connection_uri="sqlite:///reports.db", + agent_name="ReportGenerationAgent", + ) + assert result is mock_db_cls.return_value + + @patch("aieng.agent_evals.db_manager.ReadOnlySqlDatabase") + def test_returns_cached_instance(self, mock_db_cls): + """Repeated calls return the same instance without re-creating.""" + mock_configs = MagicMock() + mock_configs.report_generation_db.build_uri.return_value = "sqlite:///reports.db" + manager = DbManager(configs=mock_configs) + + first = manager.report_generation_db() + second = manager.report_generation_db() + + assert first is second + assert mock_db_cls.call_count == 1 + + +class TestClose: + """Tests for close() method.""" + + @patch("aieng.agent_evals.db_manager.ReadOnlySqlDatabase") + def test_disposes_both_connections(self, mock_db_cls): + """close() disposes both DB connections and sets them to None.""" + mock_aml = MagicMock() + mock_report = MagicMock() + mock_db_cls.side_effect = [mock_aml, mock_report] + + mock_configs = MagicMock() + mock_configs.aml_db.build_uri.return_value = "sqlite:///aml.db" + mock_configs.report_generation_db.build_uri.return_value = "sqlite:///reports.db" + manager = DbManager(configs=mock_configs) + + manager.aml_db() + manager.report_generation_db() + + manager.close() + + mock_aml.close.assert_called_once() + mock_report.close.assert_called_once() + assert manager._aml_db is None + assert manager._report_generation_db is None + + def test_idempotent_when_no_connections(self): + """close() is a no-op when no connections have been created.""" + mock_configs = MagicMock() + manager = DbManager(configs=mock_configs) + manager.close() # Should not raise diff --git a/implementations/aml_investigation/agent.py b/implementations/aml_investigation/agent.py index f3a2b6e3..0a4e01bb 100644 --- a/implementations/aml_investigation/agent.py +++ b/implementations/aml_investigation/agent.py @@ -19,8 +19,8 @@ import google.genai.types from aieng.agent_evals.aml_investigation.data import AnalystOutput, CaseRecord -from aieng.agent_evals.async_client_manager import AsyncClientManager from aieng.agent_evals.async_utils import rate_limited +from aieng.agent_evals.db_manager import DbManager from aieng.agent_evals.tools import ReadOnlySqlDatabase from dotenv import load_dotenv from google.adk.agents import Agent @@ -94,15 +94,13 @@ @lru_cache(maxsize=1) def _get_db() -> ReadOnlySqlDatabase: """Lazily construct the read-only database tool from environment configuration.""" - client_manager = AsyncClientManager().get_instance() - return client_manager.aml_db() + return DbManager.get_instance().aml_db() -async def _try_close_db() -> None: +def _try_close_db() -> None: """Close the lazily initialized database tool if it was created.""" if _get_db.cache_info().currsize: - client_manager = AsyncClientManager().get_instance() - await client_manager.close() + DbManager.get_instance().close() _get_db.cache_clear() @@ -289,7 +287,7 @@ async def _main() -> None: logger.info(" TP=%d FP=%d", tp, fp) logger.info(" FN=%d TN=%d", fn, tn) finally: - await _try_close_db() + _try_close_db() if __name__ == "__main__": diff --git a/implementations/report_generation/data/import_online_retail_data.py b/implementations/report_generation/data/import_online_retail_data.py index 79a2c835..baee0a90 100644 --- a/implementations/report_generation/data/import_online_retail_data.py +++ b/implementations/report_generation/data/import_online_retail_data.py @@ -35,7 +35,7 @@ def main(dataset_path: str): dataset_path : str The path to the CSV file containing the dataset. """ - client_manager = AsyncClientManager().get_instance() + client_manager = AsyncClientManager.get_instance() assert client_manager.configs.report_generation_db, "Report generation database configuration is missing" assert client_manager.configs.report_generation_db.database, "Report generation database path is missing" diff --git a/implementations/report_generation/demo.py b/implementations/report_generation/demo.py index 7771b9b7..a2ef39a4 100644 --- a/implementations/report_generation/demo.py +++ b/implementations/report_generation/demo.py @@ -14,6 +14,7 @@ import click import gradio as gr from aieng.agent_evals.async_client_manager import AsyncClientManager +from aieng.agent_evals.db_manager import DbManager from aieng.agent_evals.report_generation.agent import get_report_generation_agent from aieng.agent_evals.report_generation.prompts import MAIN_AGENT_INSTRUCTIONS from dotenv import load_dotenv @@ -144,6 +145,7 @@ def start_gradio_app(enable_trace: bool = True, enable_public_link: bool = False allowed_paths=[str(get_reports_output_path().absolute())], ) finally: + DbManager.get_instance().close() asyncio.run(AsyncClientManager.get_instance().close())