From bcf7aef1f2b99eebef3c398c48fca6cbca9a1ef3 Mon Sep 17 00:00:00 2001 From: GitMarco27 Date: Tue, 9 Sep 2025 14:18:25 +0000 Subject: [PATCH 1/2] feat: full async implementatio nof DatabaseSessionService --- pyproject.toml | 3 +- .../adk/sessions/database_session_service.py | 108 ++++++++++-------- .../sessions/test_session_service.py | 2 +- 3 files changed, 63 insertions(+), 50 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 714d58b1c5..2b36ca58ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,10 +106,11 @@ eval = [ test = [ # go/keep-sorted start "a2a-sdk>=0.3.0,<0.4.0;python_version>='3.10'", + "aiosqlite>=0.21.0", # For database session service tests "anthropic>=0.43.0", # For anthropic model tests "kubernetes>=29.0.0", # For GkeCodeExecutor "langchain-community>=0.3.17", - "langgraph>=0.2.60, <= 0.4.10", # For LangGraphAgent + "langgraph>=0.2.60, <= 0.4.10", # For LangGraphAgent "litellm>=1.75.5, <2.0.0", # For LiteLLM tests "llama-index-readers-file>=0.4.0", # For retrieval tests "openai>=1.100.2", # For LiteLLM diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 959524c689..6733dca627 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -29,20 +29,21 @@ from sqlalchemy import event from sqlalchemy import ForeignKeyConstraint from sqlalchemy import func +from sqlalchemy import select from sqlalchemy import Text from sqlalchemy.dialects import mysql from sqlalchemy.dialects import postgresql -from sqlalchemy.engine import create_engine -from sqlalchemy.engine import Engine from sqlalchemy.exc import ArgumentError +from sqlalchemy.ext.asyncio import async_sessionmaker +from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory +from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.inspection import inspect from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship -from sqlalchemy.orm import Session as DatabaseSessionFactory -from sqlalchemy.orm import sessionmaker from sqlalchemy.schema import MetaData from sqlalchemy.types import DateTime from sqlalchemy.types import PickleType @@ -390,11 +391,11 @@ def __init__(self, db_url: str, **kwargs: Any): # 2. Create all tables based on schema # 3. Initialize all properties try: - db_engine = create_engine(db_url, **kwargs) + db_engine = create_async_engine(db_url, **kwargs) if db_engine.dialect.name == "sqlite": # Set sqlite pragma to enable foreign keys constraints - event.listen(db_engine, "connect", set_sqlite_pragma) + event.listen(db_engine.sync_engine, "connect", set_sqlite_pragma) except Exception as e: if isinstance(e, ArgumentError): @@ -413,18 +414,23 @@ def __init__(self, db_url: str, **kwargs: Any): local_timezone = get_localzone() logger.info("Local timezone: %s", local_timezone) - self.db_engine: Engine = db_engine + self.db_engine: AsyncEngine = db_engine self.metadata: MetaData = MetaData() - self.inspector = inspect(self.db_engine) # DB session factory method - self.database_session_factory: sessionmaker[DatabaseSessionFactory] = ( - sessionmaker(bind=self.db_engine) - ) + self.database_session_factory: async_sessionmaker[ + DatabaseSessionFactory + ] = async_sessionmaker(bind=self.db_engine) + + # Flag to indicate if tables are created + self._tables_created = False - # Uncomment to recreate DB every time - # Base.metadata.drop_all(self.db_engine) - Base.metadata.create_all(self.db_engine) + async def _ensure_tables_created(self): + """Ensure database tables are created. This is called lazily.""" + if not self._tables_created: + async with self.db_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + self._tables_created = True @override async def create_session( @@ -440,12 +446,11 @@ async def create_session( # 3. Add the object to the table # 4. Build the session object with generated id # 5. Return the session - - with self.database_session_factory() as sql_session: - + await self._ensure_tables_created() + async with self.database_session_factory() as sql_session: # Fetch app and user states from storage - storage_app_state = sql_session.get(StorageAppState, (app_name)) - storage_user_state = sql_session.get( + storage_app_state = await sql_session.get(StorageAppState, (app_name)) + storage_user_state = await sql_session.get( StorageUserState, (app_name, user_id) ) @@ -485,9 +490,9 @@ async def create_session( state=session_state, ) sql_session.add(storage_session) - sql_session.commit() + await sql_session.commit() - sql_session.refresh(storage_session) + await sql_session.refresh(storage_session) # Merge states for response merged_state = _merge_state(app_state, user_state, session_state) @@ -503,11 +508,12 @@ async def get_session( session_id: str, config: Optional[GetSessionConfig] = None, ) -> Optional[Session]: + await self._ensure_tables_created() # 1. Get the storage session entry from session table # 2. Get all the events based on session id and filtering config # 3. Convert and return the session - with self.database_session_factory() as sql_session: - storage_session = sql_session.get( + async with self.database_session_factory() as sql_session: + storage_session = await sql_session.get( StorageSession, (app_name, user_id, session_id) ) if storage_session is None: @@ -519,24 +525,24 @@ async def get_session( else: timestamp_filter = True - storage_events = ( - sql_session.query(StorageEvent) + stmt = ( + select(StorageEvent) .filter(StorageEvent.app_name == app_name) .filter(StorageEvent.session_id == storage_session.id) .filter(StorageEvent.user_id == user_id) .filter(timestamp_filter) .order_by(StorageEvent.timestamp.desc()) - .limit( - config.num_recent_events - if config and config.num_recent_events - else None - ) - .all() ) + if config and config.num_recent_events: + stmt = stmt.limit(config.num_recent_events) + + result = await sql_session.execute(stmt) + storage_events = result.scalars().all() + # Fetch states from storage - storage_app_state = sql_session.get(StorageAppState, (app_name)) - storage_user_state = sql_session.get( + storage_app_state = await sql_session.get(StorageAppState, (app_name)) + storage_user_state = await sql_session.get( StorageUserState, (app_name, user_id) ) @@ -556,17 +562,19 @@ async def get_session( async def list_sessions( self, *, app_name: str, user_id: str ) -> ListSessionsResponse: - with self.database_session_factory() as sql_session: - results = ( - sql_session.query(StorageSession) + await self._ensure_tables_created() + async with self.database_session_factory() as sql_session: + stmt = ( + select(StorageSession) .filter(StorageSession.app_name == app_name) .filter(StorageSession.user_id == user_id) - .all() ) + result = await sql_session.execute(stmt) + results = result.scalars().all() # Fetch states from storage - storage_app_state = sql_session.get(StorageAppState, (app_name)) - storage_user_state = sql_session.get( + storage_app_state = await sql_session.get(StorageAppState, (app_name)) + storage_user_state = await sql_session.get( StorageUserState, (app_name, user_id) ) @@ -585,25 +593,27 @@ async def list_sessions( async def delete_session( self, app_name: str, user_id: str, session_id: str ) -> None: - with self.database_session_factory() as sql_session: + await self._ensure_tables_created() + async with self.database_session_factory() as sql_session: stmt = delete(StorageSession).where( StorageSession.app_name == app_name, StorageSession.user_id == user_id, StorageSession.id == session_id, ) - sql_session.execute(stmt) - sql_session.commit() + await sql_session.execute(stmt) + await sql_session.commit() @override async def append_event(self, session: Session, event: Event) -> Event: + await self._ensure_tables_created() if event.partial: return event # 1. Check if timestamp is stale # 2. Update session attributes based on event config # 3. Store event to table - with self.database_session_factory() as sql_session: - storage_session = sql_session.get( + async with self.database_session_factory() as sql_session: + storage_session = await sql_session.get( StorageSession, (session.app_name, session.user_id, session.id) ) @@ -617,8 +627,10 @@ async def append_event(self, session: Session, event: Event) -> Event: ) # Fetch states from storage - storage_app_state = sql_session.get(StorageAppState, (session.app_name)) - storage_user_state = sql_session.get( + storage_app_state = await sql_session.get( + StorageAppState, (session.app_name) + ) + storage_user_state = await sql_session.get( StorageUserState, (session.app_name, session.user_id) ) @@ -649,8 +661,8 @@ async def append_event(self, session: Session, event: Event) -> Event: sql_session.add(StorageEvent.from_event(session, event)) - sql_session.commit() - sql_session.refresh(storage_session) + await sql_session.commit() + await sql_session.refresh(storage_session) # Update timestamp with commit time session.last_update_time = storage_session.update_timestamp_tz diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index c2a3a1d98f..4e00eed881 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -35,7 +35,7 @@ def get_session_service( ): """Creates a session service for testing.""" if service_type == SessionServiceType.DATABASE: - return DatabaseSessionService('sqlite:///:memory:') + return DatabaseSessionService('sqlite+aiosqlite:///:memory:') return InMemorySessionService() From ebf7ec36c7d655a37bcc9c82e7fb0b5739e15013 Mon Sep 17 00:00:00 2001 From: GitMarco27 Date: Tue, 9 Sep 2025 14:29:57 +0000 Subject: [PATCH 2/2] fix: implement thread-safe creation in DatabaseSessionService --- .../adk/sessions/database_session_service.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 6733dca627..d49da1999c 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import annotations +import asyncio import copy from datetime import datetime from datetime import timezone @@ -424,13 +425,20 @@ def __init__(self, db_url: str, **kwargs: Any): # Flag to indicate if tables are created self._tables_created = False + # Lock to ensure thread-safe table creation + self._table_creation_lock = asyncio.Lock() async def _ensure_tables_created(self): """Ensure database tables are created. This is called lazily.""" - if not self._tables_created: - async with self.db_engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - self._tables_created = True + if self._tables_created: + return + + async with self._table_creation_lock: + # Double-check after acquiring the lock + if not self._tables_created: + async with self.db_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + self._tables_created = True @override async def create_session(