Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 2 additions & 53 deletions aieng-eval-agents/aieng/agent_evals/async_client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import logging

from aieng.agent_evals.configs import Configs
from aieng.agent_evals.tools import ReadOnlySqlDatabase
from langfuse import Langfuse
from openai import AsyncOpenAI

Expand Down Expand Up @@ -63,8 +62,6 @@ def __init__(self, configs: Configs | None = None) -> None:
"""
self._configs: Configs | None = configs
self._openai_client: AsyncOpenAI | None = None
self._aml_db: ReadOnlySqlDatabase | None = None
self._report_generation_db: ReadOnlySqlDatabase | None = None
self._langfuse_client: Langfuse | None = None
self._otel_instrumented: bool = False
self._initialized: bool = False
Expand Down Expand Up @@ -98,46 +95,6 @@ def openai_client(self) -> AsyncOpenAI:
self._initialized = True
return self._openai_client

def report_generation_db(self, agent_name: str = "ReportGenerationAgent") -> ReadOnlySqlDatabase:
"""Get or create Report Generation database connection.

Returns
-------
ReadOnlySqlDatabase
The Report Generation database connection instance.
"""
if self._report_generation_db is None:
if self.configs.report_generation_db is None:
raise ValueError("Report Generation database configuration is missing.")

self._report_generation_db = ReadOnlySqlDatabase(
connection_uri=self.configs.report_generation_db.build_uri(),
agent_name=agent_name,
)
self._initialized = True

return self._report_generation_db

def aml_db(self, agent_name: str = "FraudInvestigationAnalyst") -> ReadOnlySqlDatabase:
"""Get or create AML database connection.

Returns
-------
ReadOnlySqlDatabase
The Report Generation database connection instance.
"""
if self._aml_db is None:
if self.configs.aml_db is None:
raise ValueError("AML database configuration is missing.")

self._aml_db = ReadOnlySqlDatabase(
connection_uri=self.configs.aml_db.build_uri(),
agent_name=agent_name,
)
self._initialized = True

return self._aml_db

@property
def langfuse_client(self) -> Langfuse:
"""Get or create Langfuse client.
Expand Down Expand Up @@ -183,21 +140,13 @@ def otel_instrumented(self, value: bool) -> None:
async def close(self) -> None:
"""Close all initialized async clients.

This method closes the OpenAI client, database connections, and Langfuse
client if they have been initialized.
This method closes the OpenAI client and Langfuse client
if they have been initialized.
"""
if self._openai_client is not None:
await self._openai_client.close()
self._openai_client = None

if self._aml_db is not None:
self._aml_db.close()
self._aml_db = None

if self._report_generation_db is not None:
self._report_generation_db.close()
self._report_generation_db = None

if self._langfuse_client is not None:
self._langfuse_client.flush()
self._langfuse_client = None
Expand Down
137 changes: 137 additions & 0 deletions aieng-eval-agents/aieng/agent_evals/db_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""Database connection manager for Gradio applications.

Provides centralized DB lifecycle management independent of async client handling,
avoiding circular imports with the tools package.
"""

import logging

from aieng.agent_evals.configs import Configs
from aieng.agent_evals.tools.sql_database import ReadOnlySqlDatabase


logger = logging.getLogger(__name__)


class DbManager:
"""Manages database connections with lazy initialization.

Parameters
----------
configs : Configs | None, optional
Configuration object. If ``None``, created lazily on first access.
"""

_singleton_instance: "DbManager | None" = None

@classmethod
def get_instance(cls) -> "DbManager":
"""Get the singleton instance of the DB manager.

Returns
-------
DbManager
The singleton instance of the DB manager.
"""
if cls._singleton_instance is None:
cls._singleton_instance = DbManager()
return cls._singleton_instance

def __init__(self, configs: Configs | None = None) -> None:
self._configs: Configs | None = configs
self._aml_db: ReadOnlySqlDatabase | None = None
self._report_generation_db: ReadOnlySqlDatabase | None = None

@property
def configs(self) -> Configs:
"""Get or create configs instance.

Returns
-------
Configs
The configuration instance.
"""
if self._configs is None:
self._configs = Configs() # type: ignore[call-arg]
return self._configs

@configs.setter
def configs(self, value: Configs) -> None:
"""Set the configs instance.

Parameters
----------
value : Configs
The configuration instance to set.
"""
self._configs = value

def aml_db(self, agent_name: str = "FraudInvestigationAnalyst") -> ReadOnlySqlDatabase:
"""Get or create the AML database connection.

Parameters
----------
agent_name : str, optional
Name of the agent using this connection,
by default ``"FraudInvestigationAnalyst"``.

Returns
-------
ReadOnlySqlDatabase
The AML database connection instance.

Raises
------
ValueError
If AML database configuration is missing.
"""
if self._aml_db is None:
if self.configs.aml_db is None:
raise ValueError("AML database configuration is missing.")

self._aml_db = ReadOnlySqlDatabase(
connection_uri=self.configs.aml_db.build_uri(),
agent_name=agent_name,
)

return self._aml_db

def report_generation_db(self, agent_name: str = "ReportGenerationAgent") -> ReadOnlySqlDatabase:
"""Get or create the Report Generation database connection.

Parameters
----------
agent_name : str, optional
Name of the agent using this connection,
by default ``"ReportGenerationAgent"``.

Returns
-------
ReadOnlySqlDatabase
The Report Generation database connection instance.

Raises
------
ValueError
If Report Generation database configuration is missing.
"""
if self._report_generation_db is None:
if self.configs.report_generation_db is None:
raise ValueError("Report Generation database configuration is missing.")

self._report_generation_db = ReadOnlySqlDatabase(
connection_uri=self.configs.report_generation_db.build_uri(),
agent_name=agent_name,
)

return self._report_generation_db

def close(self) -> None:
"""Dispose of all database connections."""
if self._aml_db is not None:
self._aml_db.close()
self._aml_db = None

if self._report_generation_db is not None:
self._report_generation_db.close()
self._report_generation_db = None
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import Any

from aieng.agent_evals.async_client_manager import AsyncClientManager
from aieng.agent_evals.db_manager import DbManager
from aieng.agent_evals.langfuse import setup_langfuse_tracer
from aieng.agent_evals.report_generation.file_writer import ReportFileWriter
from google.adk.agents import Agent
Expand Down Expand Up @@ -60,6 +61,7 @@ def get_report_generation_agent(

# Get the client manager singleton instance
client_manager = AsyncClientManager.get_instance()
db_manager = DbManager.get_instance()
report_file_writer = ReportFileWriter(reports_output_path)

# Define an agent using Google ADK
Expand All @@ -68,8 +70,8 @@ def get_report_generation_agent(
model=client_manager.configs.default_worker_model,
instruction=instructions,
tools=[
client_manager.report_generation_db().execute,
client_manager.report_generation_db().get_schema_info,
db_manager.report_generation_db().execute,
db_manager.report_generation_db().get_schema_info,
report_file_writer.write_xlsx,
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Any

from aieng.agent_evals.async_client_manager import AsyncClientManager
from aieng.agent_evals.db_manager import DbManager
from aieng.agent_evals.report_generation.agent import EventParser, EventType, get_report_generation_agent
from aieng.agent_evals.report_generation.prompts import (
MAIN_AGENT_INSTRUCTIONS,
Expand Down Expand Up @@ -106,6 +107,7 @@ async def evaluate(

try:
# Gracefully close the services
DbManager.get_instance().close()
await client_manager.close()
except Exception as e:
logger.warning(f"Client manager services not closed successfully: {e}")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Tests for AsyncClientManager singleton and client lifecycle."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from aieng.agent_evals.async_client_manager import AsyncClientManager


@pytest.fixture(autouse=True)
def _reset_singleton():
"""Reset AsyncClientManager singleton before and after each test."""
AsyncClientManager._singleton_instance = None
yield
AsyncClientManager._singleton_instance = None


class TestGetInstance:
"""Tests for the get_instance() class method."""

def test_returns_same_instance(self):
"""get_instance() always returns the same object."""
first = AsyncClientManager.get_instance()
second = AsyncClientManager.get_instance()
assert first is second

def test_constructor_creates_separate_instance(self):
"""Direct constructor creates a different object than get_instance()."""
singleton = AsyncClientManager.get_instance()
separate = AsyncClientManager()
assert singleton is not separate


class TestConfigs:
"""Tests for lazy config creation."""

def test_lazy_config_creation(self):
"""Accessing .configs creates a Configs instance when none was provided."""
manager = AsyncClientManager()
assert manager._configs is None
with patch("aieng.agent_evals.async_client_manager.Configs") as mock_configs_cls:
mock_instance = MagicMock()
mock_configs_cls.return_value = mock_instance
result = manager.configs
assert result is mock_instance


class TestClose:
"""Tests for close() method."""

@pytest.mark.asyncio
async def test_closes_openai_client(self):
"""close() closes the OpenAI client."""
manager = AsyncClientManager()
mock_client = AsyncMock()
manager._openai_client = mock_client
manager._initialized = True

await manager.close()

mock_client.close.assert_awaited_once()
assert manager._openai_client is None

@pytest.mark.asyncio
async def test_flushes_langfuse(self):
"""close() flushes and clears the Langfuse client."""
manager = AsyncClientManager()
mock_langfuse = MagicMock()
manager._langfuse_client = mock_langfuse
manager._initialized = True

await manager.close()

mock_langfuse.flush.assert_called_once()
assert manager._langfuse_client is None

@pytest.mark.asyncio
async def test_resets_initialized(self):
"""close() sets _initialized to False."""
manager = AsyncClientManager()
manager._initialized = True

await manager.close()

assert manager._initialized is False
Loading