Skip to content

Commit f7e2a7a

Browse files
DeanChensjcopybara-github
authored andcommitted
chore: Make VertexAiSessionService fully asynchronous
This CL refactors VertexAiSessionService to use the asynchronous aio client for all Vertex AI API calls. This ensures that the service methods are non-blocking and can be used effectively in an asyncio environment. PiperOrigin-RevId: 824573356
1 parent 48ddd07 commit f7e2a7a

File tree

2 files changed

+47
-39
lines changed

2 files changed

+47
-39
lines changed

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16+
import asyncio
1617
import datetime
1718
import json
1819
import logging
@@ -104,7 +105,7 @@ async def create_session(
104105

105106
if _is_vertex_express_mode(self._project, self._location):
106107
config['wait_for_completion'] = False
107-
api_response = api_client.agent_engines.sessions.create(
108+
api_response = await api_client.aio.agent_engines.sessions.create(
108109
name=f'reasoningEngines/{reasoning_engine_id}',
109110
user_id=user_id,
110111
config=config,
@@ -123,7 +124,7 @@ async def create_session(
123124
)
124125
async def _poll_session_resource():
125126
try:
126-
return api_client.agent_engines.sessions.get(
127+
return await api_client.aio.agent_engines.sessions.get(
127128
name=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}'
128129
)
129130
except ClientError:
@@ -135,11 +136,11 @@ async def _poll_session_resource():
135136
except Exception as exc:
136137
raise ValueError('Failed to create session.') from exc
137138

138-
get_session_response = api_client.agent_engines.sessions.get(
139+
get_session_response = await api_client.aio.agent_engines.sessions.get(
139140
name=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}'
140141
)
141142
else:
142-
api_response = api_client.agent_engines.sessions.create(
143+
api_response = await api_client.aio.agent_engines.sessions.create(
143144
name=f'reasoningEngines/{reasoning_engine_id}',
144145
user_id=user_id,
145146
config=config,
@@ -168,10 +169,28 @@ async def get_session(
168169
) -> Optional[Session]:
169170
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
170171
api_client = self._get_api_client()
172+
session_resource_name = (
173+
f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}'
174+
)
175+
176+
# Get session resource and events in parallel.
177+
list_events_kwargs = {}
178+
if config and not config.num_recent_events and config.after_timestamp:
179+
# Filter events based on timestamp.
180+
list_events_kwargs['config'] = {
181+
'filter': 'timestamp>="{}"'.format(
182+
datetime.datetime.fromtimestamp(
183+
config.after_timestamp, tz=datetime.timezone.utc
184+
).isoformat()
185+
)
186+
}
171187

172-
# Get session resource
173-
get_session_response = api_client.agent_engines.sessions.get(
174-
name=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
188+
get_session_response, events_iterator = await asyncio.gather(
189+
api_client.aio.agent_engines.sessions.get(name=session_resource_name),
190+
api_client.aio.agent_engines.sessions.events.list(
191+
name=session_resource_name,
192+
**list_events_kwargs,
193+
),
175194
)
176195

177196
if get_session_response.user_id != user_id:
@@ -187,29 +206,14 @@ async def get_session(
187206
state=getattr(get_session_response, 'session_state', None) or {},
188207
last_update_time=update_timestamp,
189208
)
190-
191-
list_events_kwargs = {}
192-
if config and not config.num_recent_events and config.after_timestamp:
193-
list_events_kwargs['config'] = {
194-
'filter': 'timestamp>="{}"'.format(
195-
datetime.datetime.fromtimestamp(
196-
config.after_timestamp, tz=datetime.timezone.utc
197-
).isoformat()
198-
)
199-
}
200-
201-
events_iterator = api_client.agent_engines.sessions.events.list(
202-
name=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
203-
**list_events_kwargs,
204-
)
205209
session.events += [
206210
_from_api_event(event)
207211
for event in events_iterator
208212
if event.timestamp.timestamp() <= update_timestamp
209213
]
210214

211-
# Filter events based on config
212215
if config:
216+
# Filter events based on num_recent_events.
213217
if config.num_recent_events:
214218
session.events = session.events[-config.num_recent_events :]
215219

@@ -226,7 +230,7 @@ async def list_sessions(
226230
config = {}
227231
if user_id is not None:
228232
config['filter'] = f'user_id="{user_id}"'
229-
sessions_iterator = api_client.agent_engines.sessions.list(
233+
sessions_iterator = await api_client.aio.agent_engines.sessions.list(
230234
name=f'reasoningEngines/{reasoning_engine_id}',
231235
config=config,
232236
)
@@ -251,7 +255,7 @@ async def delete_session(
251255
api_client = self._get_api_client()
252256

253257
try:
254-
api_client.agent_engines.sessions.delete(
258+
await api_client.aio.agent_engines.sessions.delete(
255259
name=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
256260
)
257261
except Exception as e:
@@ -308,7 +312,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
308312
)
309313
config['event_metadata'] = metadata_dict
310314

311-
api_client.agent_engines.sessions.events.append(
315+
await api_client.aio.agent_engines.sessions.events.append(
312316
name=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}',
313317
author=event.author,
314318
invocation_id=event.invocation_id,

tests/unittests/sessions/test_vertex_ai_session_service.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -235,22 +235,24 @@ def __init__(self) -> None:
235235
"""Initializes MockClient."""
236236
self.session_dict: dict[str, Any] = {}
237237
self.event_dict: dict[str, Tuple[List[Any], Optional[str]]] = {}
238-
self.agent_engines = mock.Mock()
239-
self.agent_engines.sessions.get.side_effect = self._get_session
240-
self.agent_engines.sessions.list.side_effect = self._list_sessions
241-
self.agent_engines.sessions.delete.side_effect = self._delete_session
242-
self.agent_engines.sessions.create.side_effect = self._create_session
243-
self.agent_engines.sessions.events.list.side_effect = self._list_events
244-
self.agent_engines.sessions.events.append.side_effect = self._append_event
238+
self.aio = mock.Mock()
239+
self.aio.agent_engines.sessions.get.side_effect = self._get_session
240+
self.aio.agent_engines.sessions.list.side_effect = self._list_sessions
241+
self.aio.agent_engines.sessions.delete.side_effect = self._delete_session
242+
self.aio.agent_engines.sessions.create.side_effect = self._create_session
243+
self.aio.agent_engines.sessions.events.list.side_effect = self._list_events
244+
self.aio.agent_engines.sessions.events.append.side_effect = (
245+
self._append_event
246+
)
245247
self.last_create_session_config: dict[str, Any] = {}
246248

247-
def _get_session(self, name: str):
249+
async def _get_session(self, name: str):
248250
session_id = name.split('/')[-1]
249251
if session_id in self.session_dict:
250252
return _convert_to_object(self.session_dict[session_id])
251253
raise api_core_exceptions.NotFound(f'Session not found: {session_id}')
252254

253-
def _list_sessions(self, name: str, config: dict[str, Any]):
255+
async def _list_sessions(self, name: str, config: dict[str, Any]):
254256
filter_val = config.get('filter', '')
255257
user_id_match = re.search(r'user_id="([^"]+)"', filter_val)
256258
if user_id_match:
@@ -271,11 +273,13 @@ def _list_sessions(self, name: str, config: dict[str, Any]):
271273
_convert_to_object(session) for session in self.session_dict.values()
272274
]
273275

274-
def _delete_session(self, name: str):
276+
async def _delete_session(self, name: str):
275277
session_id = name.split('/')[-1]
276278
self.session_dict.pop(session_id)
277279

278-
def _create_session(self, name: str, user_id: str, config: dict[str, Any]):
280+
async def _create_session(
281+
self, name: str, user_id: str, config: dict[str, Any]
282+
):
279283
self.last_create_session_config = config
280284
new_session_id = '4'
281285
self.session_dict[new_session_id] = {
@@ -299,7 +303,7 @@ def _create_session(self, name: str, user_id: str, config: dict[str, Any]):
299303
'response': self.session_dict['4'],
300304
})
301305

302-
def _list_events(self, name: str, **kwargs):
306+
async def _list_events(self, name: str, **kwargs):
303307
session_id = name.split('/')[-1]
304308
events = []
305309
if session_id in self.event_dict:
@@ -322,7 +326,7 @@ def _list_events(self, name: str, **kwargs):
322326
]
323327
return [_convert_to_object(event) for event in events]
324328

325-
def _append_event(
329+
async def _append_event(
326330
self,
327331
name: str,
328332
author: str,

0 commit comments

Comments
 (0)