Skip to content

Commit a985cc3

Browse files
DeanChensjcopybara-github
authored andcommitted
feat: Support return all sessions when user id is none
PiperOrigin-RevId: 819884236
1 parent b650181 commit a985cc3

File tree

6 files changed

+147
-40
lines changed

6 files changed

+147
-40
lines changed

src/google/adk/sessions/base_session_service.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,18 @@ async def get_session(
8383

8484
@abc.abstractmethod
8585
async def list_sessions(
86-
self, *, app_name: str, user_id: str
86+
self, *, app_name: str, user_id: Optional[str] = None
8787
) -> ListSessionsResponse:
88-
"""Lists all the sessions."""
88+
"""Lists all the sessions for a user.
89+
90+
Args:
91+
app_name: The name of the app.
92+
user_id: The ID of the user. If not provided, lists all sessions for all
93+
users.
94+
95+
Returns:
96+
A ListSessionsResponse containing the sessions.
97+
"""
8998

9099
@abc.abstractmethod
91100
async def delete_session(

src/google/adk/sessions/database_session_service.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -554,30 +554,42 @@ async def get_session(
554554

555555
@override
556556
async def list_sessions(
557-
self, *, app_name: str, user_id: str
557+
self, *, app_name: str, user_id: Optional[str] = None
558558
) -> ListSessionsResponse:
559559
with self.database_session_factory() as sql_session:
560-
results = (
561-
sql_session.query(StorageSession)
562-
.filter(StorageSession.app_name == app_name)
563-
.filter(StorageSession.user_id == user_id)
564-
.all()
560+
query = sql_session.query(StorageSession).filter(
561+
StorageSession.app_name == app_name
565562
)
563+
if user_id is not None:
564+
query = query.filter(StorageSession.user_id == user_id)
565+
results = query.all()
566566

567-
# Fetch states from storage
567+
# Fetch app state from storage
568568
storage_app_state = sql_session.get(StorageAppState, (app_name))
569-
storage_user_state = sql_session.get(
570-
StorageUserState, (app_name, user_id)
571-
)
572-
573569
app_state = storage_app_state.state if storage_app_state else {}
574-
user_state = storage_user_state.state if storage_user_state else {}
570+
571+
# Fetch user state(s) from storage
572+
user_states_map = {}
573+
if user_id is not None:
574+
storage_user_state = sql_session.get(
575+
StorageUserState, (app_name, user_id)
576+
)
577+
if storage_user_state:
578+
user_states_map[user_id] = storage_user_state.state
579+
else:
580+
all_user_states_for_app = (
581+
sql_session.query(StorageUserState)
582+
.filter(StorageUserState.app_name == app_name)
583+
.all()
584+
)
585+
for storage_user_state in all_user_states_for_app:
586+
user_states_map[storage_user_state.user_id] = storage_user_state.state
575587

576588
sessions = []
577589
for storage_session in results:
578590
session_state = storage_session.state
591+
user_state = user_states_map.get(storage_session.user_id, {})
579592
merged_state = _merge_state(app_state, user_state, session_state)
580-
581593
sessions.append(storage_session.to_session(state=merged_state))
582594
return ListSessionsResponse(sessions=sessions)
583595

src/google/adk/sessions/in_memory_session_service.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -201,31 +201,41 @@ def _merge_state(
201201

202202
@override
203203
async def list_sessions(
204-
self, *, app_name: str, user_id: str
204+
self, *, app_name: str, user_id: Optional[str] = None
205205
) -> ListSessionsResponse:
206206
return self._list_sessions_impl(app_name=app_name, user_id=user_id)
207207

208208
def list_sessions_sync(
209-
self, *, app_name: str, user_id: str
209+
self, *, app_name: str, user_id: Optional[str] = None
210210
) -> ListSessionsResponse:
211211
logger.warning('Deprecated. Please migrate to the async method.')
212212
return self._list_sessions_impl(app_name=app_name, user_id=user_id)
213213

214214
def _list_sessions_impl(
215-
self, *, app_name: str, user_id: str
215+
self, *, app_name: str, user_id: Optional[str] = None
216216
) -> ListSessionsResponse:
217217
empty_response = ListSessionsResponse()
218218
if app_name not in self.sessions:
219219
return empty_response
220-
if user_id not in self.sessions[app_name]:
220+
if user_id is not None and user_id not in self.sessions[app_name]:
221221
return empty_response
222222

223223
sessions_without_events = []
224-
for session in self.sessions[app_name][user_id].values():
225-
copied_session = copy.deepcopy(session)
226-
copied_session.events = []
227-
copied_session = self._merge_state(app_name, user_id, copied_session)
228-
sessions_without_events.append(copied_session)
224+
225+
if user_id is None:
226+
for user_id in self.sessions[app_name]:
227+
for session_id in self.sessions[app_name][user_id]:
228+
session = self.sessions[app_name][user_id][session_id]
229+
copied_session = copy.deepcopy(session)
230+
copied_session.events = []
231+
copied_session = self._merge_state(app_name, user_id, copied_session)
232+
sessions_without_events.append(copied_session)
233+
else:
234+
for session in self.sessions[app_name][user_id].values():
235+
copied_session = copy.deepcopy(session)
236+
copied_session.events = []
237+
copied_session = self._merge_state(app_name, user_id, copied_session)
238+
sessions_without_events.append(copied_session)
229239
return ListSessionsResponse(sessions=sessions_without_events)
230240

231241
@override

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,22 +200,25 @@ async def get_session(
200200

201201
@override
202202
async def list_sessions(
203-
self, *, app_name: str, user_id: str
203+
self, *, app_name: str, user_id: Optional[str] = None
204204
) -> ListSessionsResponse:
205205
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
206206
api_client = self._get_api_client()
207207

208208
sessions = []
209+
config = {}
210+
if user_id is not None:
211+
config['filter'] = f'user_id="{user_id}"'
209212
sessions_iterator = api_client.agent_engines.sessions.list(
210213
name=f'reasoningEngines/{reasoning_engine_id}',
211-
config={'filter': f'user_id="{user_id}"'},
214+
config=config,
212215
)
213216

214217
for api_session in sessions_iterator:
215218
sessions.append(
216219
Session(
217220
app_name=app_name,
218-
user_id=user_id,
221+
user_id=api_session.user_id,
219222
id=api_session.name.split('/')[-1],
220223
state=getattr(api_session, 'session_state', None) or {},
221224
last_update_time=api_session.update_time.timestamp(),

tests/unittests/sessions/test_session_service.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,70 @@ async def test_create_and_list_sessions(service_type):
116116
app_name=app_name, user_id=user_id
117117
)
118118
sessions = list_sessions_response.sessions
119-
for i in range(len(sessions)):
120-
assert sessions[i].id == session_ids[i]
121-
assert sessions[i].state == {'key': 'value' + session_ids[i]}
119+
assert len(sessions) == len(session_ids)
120+
assert {s.id for s in sessions} == set(session_ids)
121+
for session in sessions:
122+
assert session.state == {'key': 'value' + session.id}
123+
124+
125+
@pytest.mark.asyncio
126+
@pytest.mark.parametrize(
127+
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
128+
)
129+
async def test_list_sessions_all_users(service_type):
130+
session_service = get_session_service(service_type)
131+
app_name = 'my_app'
132+
user_id_1 = 'user1'
133+
user_id_2 = 'user2'
134+
135+
await session_service.create_session(
136+
app_name=app_name,
137+
user_id=user_id_1,
138+
session_id='session1a',
139+
state={'key': 'value1a'},
140+
)
141+
await session_service.create_session(
142+
app_name=app_name,
143+
user_id=user_id_1,
144+
session_id='session1b',
145+
state={'key': 'value1b'},
146+
)
147+
await session_service.create_session(
148+
app_name=app_name,
149+
user_id=user_id_2,
150+
session_id='session2a',
151+
state={'key': 'value2a'},
152+
)
153+
154+
# List sessions for user1
155+
list_sessions_response_1 = await session_service.list_sessions(
156+
app_name=app_name, user_id=user_id_1
157+
)
158+
sessions_1 = list_sessions_response_1.sessions
159+
assert len(sessions_1) == 2
160+
assert {s.id for s in sessions_1} == {'session1a', 'session1b'}
161+
for session in sessions_1:
162+
if session.id == 'session1a':
163+
assert session.state == {'key': 'value1a'}
164+
else:
165+
assert session.state == {'key': 'value1b'}
166+
167+
# List sessions for user2
168+
list_sessions_response_2 = await session_service.list_sessions(
169+
app_name=app_name, user_id=user_id_2
170+
)
171+
sessions_2 = list_sessions_response_2.sessions
172+
assert len(sessions_2) == 1
173+
assert sessions_2[0].id == 'session2a'
174+
assert sessions_2[0].state == {'key': 'value2a'}
175+
176+
# List sessions for all users
177+
list_sessions_response_all = await session_service.list_sessions(
178+
app_name=app_name, user_id=None
179+
)
180+
sessions_all = list_sessions_response_all.sessions
181+
assert len(sessions_all) == 3
182+
assert {s.id for s in sessions_all} == {'session1a', 'session1b', 'session2a'}
122183

123184

124185
@pytest.mark.asyncio

tests/unittests/sessions/test_vertex_ai_session_service.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -252,19 +252,22 @@ def _get_session(self, name: str):
252252
def _list_sessions(self, name: str, config: dict[str, Any]):
253253
filter_val = config.get('filter', '')
254254
user_id_match = re.search(r'user_id="([^"]+)"', filter_val)
255-
if not user_id_match:
256-
raise ValueError(f'Could not find user_id in filter: {filter_val}')
257-
user_id = user_id_match.group(1)
258-
259-
if user_id == 'user_with_pages':
255+
if user_id_match:
256+
user_id = user_id_match.group(1)
257+
if user_id == 'user_with_pages':
258+
return [
259+
_convert_to_object(MOCK_SESSION_JSON_PAGE1),
260+
_convert_to_object(MOCK_SESSION_JSON_PAGE2),
261+
]
260262
return [
261-
_convert_to_object(MOCK_SESSION_JSON_PAGE1),
262-
_convert_to_object(MOCK_SESSION_JSON_PAGE2),
263+
_convert_to_object(session)
264+
for session in self.session_dict.values()
265+
if session['user_id'] == user_id
263266
]
267+
268+
# No user filter, return all sessions
264269
return [
265-
_convert_to_object(session)
266-
for session in self.session_dict.values()
267-
if session['user_id'] == user_id
270+
_convert_to_object(session) for session in self.session_dict.values()
268271
]
269272

270273
def _delete_session(self, name: str):
@@ -475,6 +478,15 @@ async def test_list_sessions_with_pagination():
475478
assert sessions.sessions[1].id == 'page2'
476479

477480

481+
@pytest.mark.asyncio
482+
@pytest.mark.usefixtures('mock_get_api_client')
483+
async def test_list_sessions_all_users():
484+
session_service = mock_vertex_ai_session_service()
485+
sessions = await session_service.list_sessions(app_name='123', user_id=None)
486+
assert len(sessions.sessions) == 5
487+
assert {s.id for s in sessions.sessions} == {'1', '2', '3', 'page1', 'page2'}
488+
489+
478490
@pytest.mark.asyncio
479491
@pytest.mark.usefixtures('mock_get_api_client')
480492
async def test_create_session():

0 commit comments

Comments
 (0)