Skip to content

Commit a2129f8

Browse files
committed
fix: fix and test config.after_timestamp behavior in InMemorySessionService.get_session()
1 parent b691904 commit a2129f8

File tree

2 files changed

+59
-4
lines changed

2 files changed

+59
-4
lines changed

src/google/adk/sessions/in_memory_session_service.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,14 @@ def get_session(
9595
copied_session.events = copied_session.events[
9696
-config.num_recent_events :
9797
]
98-
elif config.after_timestamp:
99-
i = len(session.events) - 1
98+
if config.after_timestamp:
99+
i = len(copied_session.events) - 1
100100
while i >= 0:
101101
if copied_session.events[i].timestamp < config.after_timestamp:
102102
break
103103
i -= 1
104104
if i >= 0:
105-
copied_session.events = copied_session.events[i:]
105+
copied_session.events = copied_session.events[i + 1:]
106106

107107
return self._merge_state(app_name, user_id, copied_session)
108108

tests/unittests/sessions/test_session_service.py

+56-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from google.adk.events import EventActions
2020
from google.adk.sessions import DatabaseSessionService
2121
from google.adk.sessions import InMemorySessionService
22+
from google.adk.sessions.base_session_service import GetSessionConfig
2223
from google.genai import types
2324

2425

@@ -183,7 +184,7 @@ def test_session_state(service_type):
183184

184185

185186
@pytest.mark.parametrize(
186-
"service_type", [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
187+
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
187188
)
188189
def test_create_new_session_will_merge_states(service_type):
189190
session_service = get_session_service(service_type)
@@ -298,3 +299,57 @@ def test_append_event_complete(service_type):
298299
)
299300
== session
300301
)
302+
303+
@pytest.mark.parametrize('service_type', [SessionServiceType.IN_MEMORY])
304+
def test_get_session_with_config(service_type):
305+
session_service = get_session_service(service_type)
306+
app_name = 'my_app'
307+
user_id = 'user'
308+
309+
num_test_events = 5
310+
session = session_service.create_session(app_name=app_name, user_id=user_id)
311+
for i in range(1, num_test_events + 1):
312+
event = Event(author='user', timestamp=i)
313+
session_service.append_event(session, event)
314+
315+
# No config, expect all events to be returned.
316+
events = session_service.get_session(
317+
app_name=app_name, user_id=user_id, session_id=session.id
318+
).events
319+
assert len(events) == num_test_events
320+
321+
# Only expect the most recent 3 events.
322+
num_recent_events = 3
323+
config = GetSessionConfig(num_recent_events=num_recent_events)
324+
events = session_service.get_session(
325+
app_name=app_name, user_id=user_id, session_id=session.id, config=config
326+
).events
327+
assert len(events) == num_recent_events
328+
assert events[0].timestamp == num_test_events - num_recent_events + 1
329+
330+
# Only expect events after timestamp 4.0 (inclusive), i.e., 2 events.
331+
after_timestamp = 4.0
332+
config = GetSessionConfig(after_timestamp=after_timestamp)
333+
events = session_service.get_session(
334+
app_name=app_name, user_id=user_id, session_id=session.id, config=config
335+
).events
336+
assert len(events) == num_test_events - after_timestamp + 1
337+
assert events[0].timestamp == after_timestamp
338+
339+
# Expect no events if none are > after_timestamp.
340+
way_after_timestamp = num_test_events * 10
341+
config = GetSessionConfig(after_timestamp=way_after_timestamp)
342+
events = session_service.get_session(
343+
app_name=app_name, user_id=user_id, session_id=session.id, config=config
344+
).events
345+
assert len(events) == 0
346+
347+
# Both filters applied, i.e., of 3 most recent events, only 2 are after
348+
# timestamp 4.0, so expect 2 events.
349+
config = GetSessionConfig(
350+
after_timestamp=after_timestamp, num_recent_events=num_recent_events
351+
)
352+
events = session_service.get_session(
353+
app_name=app_name, user_id=user_id, session_id=session.id, config=config
354+
).events
355+
assert len(events) == num_test_events - after_timestamp + 1

0 commit comments

Comments
 (0)