|
19 | 19 | from google.adk.events import EventActions
|
20 | 20 | from google.adk.sessions import DatabaseSessionService
|
21 | 21 | from google.adk.sessions import InMemorySessionService
|
| 22 | +from google.adk.sessions.base_session_service import GetSessionConfig |
22 | 23 | from google.genai import types
|
23 | 24 |
|
24 | 25 |
|
@@ -183,7 +184,7 @@ def test_session_state(service_type):
|
183 | 184 |
|
184 | 185 |
|
185 | 186 | @pytest.mark.parametrize(
|
186 |
| - "service_type", [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] |
| 187 | + 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] |
187 | 188 | )
|
188 | 189 | def test_create_new_session_will_merge_states(service_type):
|
189 | 190 | session_service = get_session_service(service_type)
|
@@ -298,3 +299,57 @@ def test_append_event_complete(service_type):
|
298 | 299 | )
|
299 | 300 | == session
|
300 | 301 | )
|
| 302 | + |
| 303 | +@pytest.mark.parametrize('service_type', [SessionServiceType.IN_MEMORY]) |
| 304 | +def test_get_session_with_config(service_type): |
| 305 | + session_service = get_session_service(service_type) |
| 306 | + app_name = 'my_app' |
| 307 | + user_id = 'user' |
| 308 | + |
| 309 | + num_test_events = 5 |
| 310 | + session = session_service.create_session(app_name=app_name, user_id=user_id) |
| 311 | + for i in range(1, num_test_events + 1): |
| 312 | + event = Event(author='user', timestamp=i) |
| 313 | + session_service.append_event(session, event) |
| 314 | + |
| 315 | + # No config, expect all events to be returned. |
| 316 | + events = session_service.get_session( |
| 317 | + app_name=app_name, user_id=user_id, session_id=session.id |
| 318 | + ).events |
| 319 | + assert len(events) == num_test_events |
| 320 | + |
| 321 | + # Only expect the most recent 3 events. |
| 322 | + num_recent_events = 3 |
| 323 | + config = GetSessionConfig(num_recent_events=num_recent_events) |
| 324 | + events = session_service.get_session( |
| 325 | + app_name=app_name, user_id=user_id, session_id=session.id, config=config |
| 326 | + ).events |
| 327 | + assert len(events) == num_recent_events |
| 328 | + assert events[0].timestamp == num_test_events - num_recent_events + 1 |
| 329 | + |
| 330 | + # Only expect events after timestamp 4.0 (inclusive), i.e., 2 events. |
| 331 | + after_timestamp = 4.0 |
| 332 | + config = GetSessionConfig(after_timestamp=after_timestamp) |
| 333 | + events = session_service.get_session( |
| 334 | + app_name=app_name, user_id=user_id, session_id=session.id, config=config |
| 335 | + ).events |
| 336 | + assert len(events) == num_test_events - after_timestamp + 1 |
| 337 | + assert events[0].timestamp == after_timestamp |
| 338 | + |
| 339 | + # Expect no events if none are > after_timestamp. |
| 340 | + way_after_timestamp = num_test_events * 10 |
| 341 | + config = GetSessionConfig(after_timestamp=way_after_timestamp) |
| 342 | + events = session_service.get_session( |
| 343 | + app_name=app_name, user_id=user_id, session_id=session.id, config=config |
| 344 | + ).events |
| 345 | + assert len(events) == 0 |
| 346 | + |
| 347 | + # Both filters applied, i.e., of 3 most recent events, only 2 are after |
| 348 | + # timestamp 4.0, so expect 2 events. |
| 349 | + config = GetSessionConfig( |
| 350 | + after_timestamp=after_timestamp, num_recent_events=num_recent_events |
| 351 | + ) |
| 352 | + events = session_service.get_session( |
| 353 | + app_name=app_name, user_id=user_id, session_id=session.id, config=config |
| 354 | + ).events |
| 355 | + assert len(events) == num_test_events - after_timestamp + 1 |
0 commit comments