feat: full async implementation of DatabaseSessionService #2889
Summary of Changes
Hello @GitMarco27, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a comprehensive asynchronous implementation for the DatabaseSessionService. The primary motivation is to make the ADK usable in fully asynchronous environments and API endpoints by preventing event loop blocking during database I/O operations. This change improves the service's compatibility and performance in modern async Python applications.
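
To illustrate what "event loop blocking" means here, the following is a minimal standard-library sketch, not code from this PR. The names `count_ticks`, `measure`, `blocking_query`, and `async_query` are hypothetical, and `time.sleep` / `asyncio.sleep` merely stand in for a synchronous and an asynchronous database driver call. A background task counts how often it gets to run while the "query" is in flight.

```python
import asyncio
import time


async def count_ticks(stop: asyncio.Event, interval: float = 0.01) -> int:
    """Count how many times this task gets scheduled until `stop` is set."""
    ticks = 0
    while not stop.is_set():
        await asyncio.sleep(interval)
        ticks += 1
    return ticks


async def measure(io_call) -> int:
    """Run `io_call` while a ticker task competes for the event loop."""
    stop = asyncio.Event()
    ticker = asyncio.create_task(count_ticks(stop))
    await asyncio.sleep(0)  # give the ticker a chance to start
    await io_call()
    stop.set()
    return await ticker


async def blocking_query() -> None:
    # Stand-in for a synchronous driver call: the whole loop is frozen.
    time.sleep(0.2)


async def async_query() -> None:
    # Stand-in for an async driver call: the loop keeps serving other tasks.
    await asyncio.sleep(0.2)


def run_demo() -> tuple[int, int]:
    blocked = asyncio.run(measure(blocking_query))
    free = asyncio.run(measure(async_query))
    return blocked, free
```

With the blocking call the ticker barely runs; with the awaited call it keeps ticking for the full duration, which is the behavior an async session service is meant to preserve for other requests.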
Highlights
- Asynchronous DatabaseSessionService: The DatabaseSessionService has been fully converted to use asynchronous patterns, enabling non-blocking database I/O operations in async environments.
- Lazy Table Creation: Database tables are now created lazily upon the first database interaction, rather than during service initialization, improving startup performance.
- Dependency Update: The aiosqlite library has been added as a test dependency to support asynchronous SQLite connections.
- Minimal Breaking Changes: The implementation aims for minimal breaking changes, primarily affecting database connection string formats for async drivers.
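
As a rough illustration of the connection string change mentioned in the last highlight, here is a small sketch. The helper `to_async_url` and the `ASYNC_DRIVERS` mapping are hypothetical and not part of the ADK; the only assumption taken from this thread is that aiosqlite and asyncpg are the async drivers in use, which SQLAlchemy addresses as `sqlite+aiosqlite` and `postgresql+asyncpg`.

```python
# Hypothetical mapping from a plain dialect name to a dialect+driver pair
# that names an async driver. Extend as needed for other databases.
ASYNC_DRIVERS = {
    "sqlite": "sqlite+aiosqlite",
    "postgresql": "postgresql+asyncpg",
}


def to_async_url(db_url: str) -> str:
    """Rewrite a URL such as 'sqlite:///x.db' to name an async driver."""
    scheme, sep, rest = db_url.partition("://")
    if not sep:
        raise ValueError(f"Not a database URL: {db_url!r}")
    if "+" in scheme:
        # A driver is already given explicitly; leave the URL untouched.
        return db_url
    try:
        async_scheme = ASYNC_DRIVERS[scheme]
    except KeyError:
        raise ValueError(f"No async driver known for dialect {scheme!r}") from None
    return f"{async_scheme}://{rest}"
```

For example, `to_async_url("sqlite:///./sessions.db")` yields `"sqlite+aiosqlite:///./sessions.db"`. A URL that already names a synchronous driver (for instance one containing `+psycopg2`) is returned unchanged by this sketch and would still need to be edited by hand.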
Code Review
This pull request does a great job of converting the DatabaseSessionService to a fully asynchronous implementation. The changes correctly adopt async/await patterns and modern SQLAlchemy 2.0 style for async operations, which will significantly benefit applications running in an async environment. The code is clean and the approach of lazy table creation is sound. I've found one potential race condition in the table creation logic and provided a detailed suggestion to resolve it using an asyncio.Lock. With that fix, this will be a solid enhancement.
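
The race condition described in this review can be reproduced with a small standard-library sketch. This is an illustration under stated assumptions rather than the PR's code: `LazyTables` and `count_create_calls` are hypothetical names, and `asyncio.sleep` stands in for the round trip that creates the tables.

```python
import asyncio


class LazyTables:
    """Toy stand-in for a service that creates its tables on first use."""

    def __init__(self, use_lock: bool):
        self.use_lock = use_lock
        self.created = False
        self.create_calls = 0
        self._lock = asyncio.Lock()

    async def _create_all(self) -> None:
        self.create_calls += 1
        await asyncio.sleep(0.01)  # stand-in for the DDL round trip
        self.created = True

    async def ensure_created(self) -> None:
        if self.created:
            return  # fast path once the tables exist
        if not self.use_lock:
            # Racy: several coroutines can pass the check above before
            # any of them has finished creating the tables.
            await self._create_all()
            return
        async with self._lock:
            # Double-check after acquiring the lock.
            if not self.created:
                await self._create_all()


async def count_create_calls(use_lock: bool, callers: int = 5) -> int:
    tables = LazyTables(use_lock)
    await asyncio.gather(*(tables.ensure_created() for _ in range(callers)))
    return tables.create_calls
```

Without the lock, every concurrent first caller triggers its own creation; with the lock and the second check, only one does. Note that `asyncio.Lock` coordinates coroutines on a single event loop and is not a thread lock.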
@gemini-code-assist review
Code Review
This pull request does an excellent job of converting the DatabaseSessionService to a fully asynchronous implementation. The changes correctly adopt sqlalchemy.ext.asyncio patterns, including the use of an async engine, async sessions, and the modern select statement syntax. The introduction of lazy, thread-safe table creation via _ensure_tables_created is a robust addition. My review focuses on performance optimizations, suggesting the use of asyncio.gather to run independent database queries concurrently, which should reduce latency in several key methods. Overall, this is a high-quality and well-executed feature implementation.
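
The asyncio.gather suggestion can be sketched as follows. This is a simplified illustration, not the reviewer's proposed patch: `fetch_state` is a hypothetical stand-in for an independent query, with `asyncio.sleep` simulating the database latency.

```python
import asyncio


async def fetch_state(name: str, delay: float = 0.05) -> dict:
    """Hypothetical stand-in for one independent database query."""
    await asyncio.sleep(delay)
    return {"source": name}


async def load_states_concurrently() -> tuple[dict, dict]:
    # Both lookups are started together, so the total wait is roughly
    # the slower of the two rather than their sum.
    app_state, user_state = await asyncio.gather(
        fetch_state("app"),
        fetch_state("user"),
    )
    return app_state, user_state
```

One caveat when applying this to real SQLAlchemy code: a single AsyncSession is not safe to share between concurrent tasks, so queries run through asyncio.gather would each need their own session rather than the one shared `sql_session`.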
@DeanChensj can you prioritize this? The database session service is a very important component for production systems based on ADK; it's a real shame this part is still sync.
+1 here -- this is a mission-critical feature.
Really waiting to see it integrated!
Thank you @GitMarco27, very much needed. I took the liberty of pulling the changes out into a standalone class:

import asyncio
import logging
from datetime import datetime
from typing import Any, Optional
from google.adk.events.event import Event
from google.adk.sessions.base_session_service import BaseSessionService, GetSessionConfig, ListSessionsResponse
from google.adk.sessions.database_session_service import (
    Base,
    StorageAppState,
    StorageEvent,
    StorageSession,
    StorageUserState,
    _extract_state_delta,
    _merge_state,
    set_sqlite_pragma,
)
from google.adk.sessions.session import Session
from sqlalchemy import delete, event, select
from sqlalchemy.exc import ArgumentError
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory
from sqlalchemy.schema import MetaData
from typing_extensions import override
from tzlocal import get_localzone

logger = logging.getLogger("adk_handler." + __name__)


class AsyncDatabaseSessionService(BaseSessionService):
    """A session service that uses a database for storage."""

    def __init__(self, db_url: str, **kwargs: Any):
        """Initialize the database session service with a database URL."""
        # 1. Create DB engine for db connection
        # 2. Defer table creation to the first database call (see _ensure_tables_created)
        # 3. Initialize all properties
        try:
            db_engine = create_async_engine(db_url, **kwargs)
            if db_engine.dialect.name == "sqlite":
                # Set sqlite pragma to enable foreign keys constraints
                event.listen(db_engine.sync_engine, "connect", set_sqlite_pragma)
        except Exception as e:
            if isinstance(e, ArgumentError):
                raise ValueError(f"Invalid database URL format or argument '{db_url}'.") from e
            if isinstance(e, ImportError):
                raise ValueError(f"Database related module not found for URL '{db_url}'.") from e
            raise ValueError(f"Failed to create database engine for URL '{db_url}'") from e

        # Get the local timezone
        local_timezone = get_localzone()
        logger.info("Local timezone: %s", local_timezone)

        self.db_engine: AsyncEngine = db_engine
        self.metadata: MetaData = MetaData()

        # DB session factory method
        self.database_session_factory: async_sessionmaker[DatabaseSessionFactory] = async_sessionmaker(
            bind=self.db_engine
        )

        # Flag to indicate if tables are created
        self._tables_created = False
        # Lock so that concurrent coroutines on the same event loop create the tables only once
        self._table_creation_lock = asyncio.Lock()

    async def _ensure_tables_created(self):
        """Ensure database tables are created. This is called lazily."""
        if self._tables_created:
            return

        async with self._table_creation_lock:
            # Double-check after acquiring the lock
            if not self._tables_created:
                async with self.db_engine.begin() as conn:
                    await conn.run_sync(Base.metadata.create_all)
                self._tables_created = True

    @override
    async def create_session(
        self,
        *,
        app_name: str,
        user_id: str,
        state: Optional[dict[str, Any]] = None,
        session_id: Optional[str] = None,
    ) -> Session:
        # 1. Populate states.
        # 2. Build storage session object
        # 3. Add the object to the table
        # 4. Build the session object with generated id
        # 5. Return the session
        await self._ensure_tables_created()
        async with self.database_session_factory() as sql_session:
            # Fetch app and user states from storage
            storage_app_state = await sql_session.get(StorageAppState, (app_name))
            storage_user_state = await sql_session.get(StorageUserState, (app_name, user_id))

            app_state = storage_app_state.state if storage_app_state else {}
            user_state = storage_user_state.state if storage_user_state else {}

            # Create state tables if not exist
            if not storage_app_state:
                storage_app_state = StorageAppState(app_name=app_name, state={})
                sql_session.add(storage_app_state)
            if not storage_user_state:
                storage_user_state = StorageUserState(app_name=app_name, user_id=user_id, state={})
                sql_session.add(storage_user_state)

            # Extract state deltas
            app_state_delta, user_state_delta, session_state = _extract_state_delta(state)

            # Apply state delta
            app_state.update(app_state_delta)
            user_state.update(user_state_delta)

            # Store app and user state
            if app_state_delta:
                storage_app_state.state = app_state
            if user_state_delta:
                storage_user_state.state = user_state

            # Store the session
            storage_session = StorageSession(
                app_name=app_name,
                user_id=user_id,
                id=session_id,
                state=session_state,
            )
            sql_session.add(storage_session)
            await sql_session.commit()

            await sql_session.refresh(storage_session)

            # Merge states for response
            merged_state = _merge_state(app_state, user_state, session_state)
            session = storage_session.to_session(state=merged_state)
        return session

    @override
    async def get_session(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        config: Optional[GetSessionConfig] = None,
    ) -> Optional[Session]:
        await self._ensure_tables_created()
        # 1. Get the storage session entry from session table
        # 2. Get all the events based on session id and filtering config
        # 3. Convert and return the session
        async with self.database_session_factory() as sql_session:
            storage_session = await sql_session.get(StorageSession, (app_name, user_id, session_id))
            if storage_session is None:
                return None

            if config and config.after_timestamp:
                after_dt = datetime.fromtimestamp(config.after_timestamp)
                timestamp_filter = StorageEvent.timestamp >= after_dt
            else:
                timestamp_filter = True

            stmt = (
                select(StorageEvent)
                .filter(StorageEvent.app_name == app_name)
                .filter(StorageEvent.session_id == storage_session.id)
                .filter(StorageEvent.user_id == user_id)
                .filter(timestamp_filter)
                .order_by(StorageEvent.timestamp.desc())
            )

            if config and config.num_recent_events:
                stmt = stmt.limit(config.num_recent_events)

            result = await sql_session.execute(stmt)
            storage_events = result.scalars().all()

            # Fetch states from storage
            storage_app_state = await sql_session.get(StorageAppState, (app_name))
            storage_user_state = await sql_session.get(StorageUserState, (app_name, user_id))

            app_state = storage_app_state.state if storage_app_state else {}
            user_state = storage_user_state.state if storage_user_state else {}
            session_state = storage_session.state

            # Merge states
            merged_state = _merge_state(app_state, user_state, session_state)

            # Convert storage session to session
            events = [e.to_event() for e in reversed(storage_events)]
            session = storage_session.to_session(state=merged_state, events=events)
        return session

    @override
    async def list_sessions(self, *, app_name: str, user_id: str) -> ListSessionsResponse:
        await self._ensure_tables_created()
        async with self.database_session_factory() as sql_session:
            stmt = (
                select(StorageSession)
                .filter(StorageSession.app_name == app_name)
                .filter(StorageSession.user_id == user_id)
            )
            result = await sql_session.execute(stmt)
            results = result.scalars().all()

            # Fetch states from storage
            storage_app_state = await sql_session.get(StorageAppState, (app_name))
            storage_user_state = await sql_session.get(StorageUserState, (app_name, user_id))

            app_state = storage_app_state.state if storage_app_state else {}
            user_state = storage_user_state.state if storage_user_state else {}

            sessions = []
            for storage_session in results:
                session_state = storage_session.state
                merged_state = _merge_state(app_state, user_state, session_state)
                sessions.append(storage_session.to_session(state=merged_state))
            return ListSessionsResponse(sessions=sessions)

    @override
    async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None:
        await self._ensure_tables_created()
        async with self.database_session_factory() as sql_session:
            stmt = delete(StorageSession).where(
                StorageSession.app_name == app_name,
                StorageSession.user_id == user_id,
                StorageSession.id == session_id,
            )
            await sql_session.execute(stmt)
            await sql_session.commit()

    @override
    async def append_event(self, session: Session, event: Event) -> Event:
        await self._ensure_tables_created()
        if event.partial:
            return event

        # 1. Check if timestamp is stale
        # 2. Update session attributes based on event config
        # 3. Store event to table
        async with self.database_session_factory() as sql_session:
            storage_session = await sql_session.get(StorageSession, (session.app_name, session.user_id, session.id))

            if storage_session.update_timestamp_tz > session.last_update_time:
                raise ValueError(
                    "The last_update_time provided in the session object"
                    f" {datetime.fromtimestamp(session.last_update_time):'%Y-%m-%d %H:%M:%S'} is"
                    " earlier than the update_time in the storage_session"
                    f" {datetime.fromtimestamp(storage_session.update_timestamp_tz):'%Y-%m-%d %H:%M:%S'}."
                    " Please check if it is a stale session."
                )

            # Fetch states from storage
            storage_app_state = await sql_session.get(StorageAppState, (session.app_name))
            storage_user_state = await sql_session.get(StorageUserState, (session.app_name, session.user_id))

            app_state = storage_app_state.state if storage_app_state else {}
            user_state = storage_user_state.state if storage_user_state else {}
            session_state = storage_session.state

            # Extract state delta
            app_state_delta = {}
            user_state_delta = {}
            session_state_delta = {}
            if event.actions:
                if event.actions.state_delta:
                    app_state_delta, user_state_delta, session_state_delta = _extract_state_delta(
                        event.actions.state_delta
                    )

            # Merge state and update storage
            if app_state_delta:
                app_state.update(app_state_delta)
                storage_app_state.state = app_state
            if user_state_delta:
                user_state.update(user_state_delta)
                storage_user_state.state = user_state
            if session_state_delta:
                session_state.update(session_state_delta)
                storage_session.state = session_state

            sql_session.add(StorageEvent.from_event(session, event))

            await sql_session.commit()
            await sql_session.refresh(storage_session)

            # Update timestamp with commit time
            session.last_update_time = storage_session.update_timestamp_tz

        # Also update the in-memory session
        await super().append_event(session=session, event=event)
        return event
Implement Full async DatabaseSessionService
Target Issue: #1005
Overview
This PR introduces an asynchronous implementation of the DatabaseSessionService with minimal breaking changes. The primary goal is to enable effective use of ADK in fully async environments and API endpoints while avoiding event loop blocking during database I/O operations.
Changes
- Converted DatabaseSessionService to use async/await patterns throughout
Testing Plan
The implementation has been tested following the project's contribution guidelines:
Unit Tests
- aiosqlite
Manual End-to-End Testing
- asyncpg driver
The implementation has also been tested using the following configurations for the LLM provider and Runner:
Breaking Changes