Skip to content

Commit a2c5f0a

Browse files
committed
Add additional task lifecycle tests
1 parent b4ccb40 commit a2c5f0a

File tree

2 files changed

+774
-0
lines changed

2 files changed

+774
-0
lines changed

tests/client/test_session_tasks.py

Lines changed: 389 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,389 @@
1+
"""Tests for client session task methods."""
2+
3+
import anyio
4+
import pytest
5+
6+
import mcp.types as types
7+
from examples.shared.in_memory_task_store import InMemoryTaskStore
8+
from mcp.client.session import ClientSession
9+
from mcp.server import Server
10+
from mcp.shared.memory import create_client_server_memory_streams
11+
12+
13+
@pytest.mark.anyio
14+
async def test_client_get_task_success():
15+
"""Test client.get_task() method with existing task."""
16+
task_store = InMemoryTaskStore()
17+
client_task_store = InMemoryTaskStore()
18+
server = Server("test", task_store=task_store)
19+
20+
# Create a task in the server's store
21+
task_id = "test-task-123"
22+
task_meta = types.TaskMetadata(taskId=task_id, keepAlive=60000)
23+
request = types.ClientRequest(types.PingRequest())
24+
await task_store.create_task(task_meta, "req-1", request.root)
25+
26+
async with create_client_server_memory_streams() as (client_streams, server_streams):
27+
client_read, client_write = client_streams
28+
server_read, server_write = server_streams
29+
30+
async with anyio.create_task_group() as tg:
31+
tg.start_soon(
32+
lambda: server.run(
33+
server_read,
34+
server_write,
35+
server.create_initialization_options(),
36+
)
37+
)
38+
39+
try:
40+
async with ClientSession(
41+
read_stream=client_read,
42+
write_stream=client_write,
43+
task_store=client_task_store,
44+
) as client_session:
45+
await client_session.initialize()
46+
47+
# Call get_task method
48+
result = await client_session.get_task(task_id)
49+
50+
assert result.taskId == task_id
51+
assert result.status == "submitted"
52+
assert result.keepAlive == 60000
53+
finally:
54+
tg.cancel_scope.cancel()
55+
56+
57+
@pytest.mark.anyio
58+
async def test_client_get_task_not_found():
59+
"""Test client.get_task() method when task doesn't exist."""
60+
task_store = InMemoryTaskStore()
61+
client_task_store = InMemoryTaskStore()
62+
server = Server("test", task_store=task_store)
63+
64+
async with create_client_server_memory_streams() as (client_streams, server_streams):
65+
client_read, client_write = client_streams
66+
server_read, server_write = server_streams
67+
68+
async with anyio.create_task_group() as tg:
69+
tg.start_soon(
70+
lambda: server.run(
71+
server_read,
72+
server_write,
73+
server.create_initialization_options(),
74+
)
75+
)
76+
77+
try:
78+
async with ClientSession(
79+
read_stream=client_read,
80+
write_stream=client_write,
81+
task_store=client_task_store,
82+
) as client_session:
83+
await client_session.initialize()
84+
85+
# Try to get non-existent task
86+
try:
87+
await client_session.get_task("non-existent")
88+
assert False, "Should have raised McpError"
89+
except Exception as e:
90+
assert "Task not found" in str(e) or str(types.INVALID_PARAMS) in str(e)
91+
finally:
92+
tg.cancel_scope.cancel()
93+
94+
95+
@pytest.mark.anyio
96+
async def test_client_get_task_result_success():
97+
"""Test client.get_task_result() method for completed task."""
98+
task_store = InMemoryTaskStore()
99+
client_task_store = InMemoryTaskStore()
100+
server = Server("test", task_store=task_store)
101+
102+
# Create a completed task with result
103+
task_id = "test-task-789"
104+
task_meta = types.TaskMetadata(taskId=task_id, keepAlive=60000)
105+
request = types.ClientRequest(types.PingRequest())
106+
await task_store.create_task(task_meta, "req-1", request.root)
107+
result = types.ServerResult(types.EmptyResult())
108+
await task_store.store_task_result(task_id, result.root)
109+
await task_store.update_task_status(task_id, "completed")
110+
111+
async with create_client_server_memory_streams() as (client_streams, server_streams):
112+
client_read, client_write = client_streams
113+
server_read, server_write = server_streams
114+
115+
async with anyio.create_task_group() as tg:
116+
tg.start_soon(
117+
lambda: server.run(
118+
server_read,
119+
server_write,
120+
server.create_initialization_options(),
121+
)
122+
)
123+
124+
try:
125+
async with ClientSession(
126+
read_stream=client_read,
127+
write_stream=client_write,
128+
task_store=client_task_store,
129+
) as client_session:
130+
await client_session.initialize()
131+
132+
# Get task result
133+
payload_result = await client_session.get_task_result(task_id, types.ServerResult)
134+
135+
# Verify we got the result back
136+
assert isinstance(payload_result.root, types.EmptyResult) # type: ignore[attr-defined]
137+
finally:
138+
tg.cancel_scope.cancel()
139+
140+
141+
@pytest.mark.anyio
142+
async def test_client_get_task_result_not_completed():
143+
"""Test client.get_task_result() method fails for non-completed task."""
144+
task_store = InMemoryTaskStore()
145+
client_task_store = InMemoryTaskStore()
146+
server = Server("test", task_store=task_store)
147+
148+
# Create a task in submitted state (not completed)
149+
task_id = "test-task-456"
150+
task_meta = types.TaskMetadata(taskId=task_id, keepAlive=60000)
151+
request = types.ClientRequest(types.PingRequest())
152+
await task_store.create_task(task_meta, "req-1", request.root)
153+
154+
async with create_client_server_memory_streams() as (client_streams, server_streams):
155+
client_read, client_write = client_streams
156+
server_read, server_write = server_streams
157+
158+
async with anyio.create_task_group() as tg:
159+
tg.start_soon(
160+
lambda: server.run(
161+
server_read,
162+
server_write,
163+
server.create_initialization_options(),
164+
)
165+
)
166+
167+
try:
168+
async with ClientSession(
169+
read_stream=client_read,
170+
write_stream=client_write,
171+
task_store=client_task_store,
172+
) as client_session:
173+
await client_session.initialize()
174+
175+
# Try to get result
176+
try:
177+
await client_session.get_task_result(task_id, types.ServerResult)
178+
assert False, "Should have raised McpError"
179+
except Exception as e:
180+
assert "not 'completed'" in str(e) or str(types.INVALID_PARAMS) in str(e)
181+
finally:
182+
tg.cancel_scope.cancel()
183+
184+
185+
@pytest.mark.anyio
186+
async def test_client_list_tasks_empty():
187+
"""Test client.list_tasks() method with no tasks."""
188+
task_store = InMemoryTaskStore()
189+
client_task_store = InMemoryTaskStore()
190+
server = Server("test", task_store=task_store)
191+
192+
async with create_client_server_memory_streams() as (client_streams, server_streams):
193+
client_read, client_write = client_streams
194+
server_read, server_write = server_streams
195+
196+
async with anyio.create_task_group() as tg:
197+
tg.start_soon(
198+
lambda: server.run(
199+
server_read,
200+
server_write,
201+
server.create_initialization_options(),
202+
)
203+
)
204+
205+
try:
206+
async with ClientSession(
207+
read_stream=client_read,
208+
write_stream=client_write,
209+
task_store=client_task_store,
210+
) as client_session:
211+
await client_session.initialize()
212+
213+
# List tasks
214+
result = await client_session.list_tasks()
215+
216+
assert result.tasks == []
217+
assert result.nextCursor is None
218+
finally:
219+
tg.cancel_scope.cancel()
220+
221+
222+
@pytest.mark.anyio
223+
async def test_client_list_tasks_with_tasks():
224+
"""Test client.list_tasks() method with multiple tasks."""
225+
task_store = InMemoryTaskStore()
226+
client_task_store = InMemoryTaskStore()
227+
server = Server("test", task_store=task_store)
228+
229+
# Create some tasks
230+
for i in range(3):
231+
task_id = f"task-{i}"
232+
task_meta = types.TaskMetadata(taskId=task_id, keepAlive=60000)
233+
request = types.ClientRequest(types.PingRequest())
234+
await task_store.create_task(task_meta, f"req-{i}", request.root)
235+
236+
async with create_client_server_memory_streams() as (client_streams, server_streams):
237+
client_read, client_write = client_streams
238+
server_read, server_write = server_streams
239+
240+
async with anyio.create_task_group() as tg:
241+
tg.start_soon(
242+
lambda: server.run(
243+
server_read,
244+
server_write,
245+
server.create_initialization_options(),
246+
)
247+
)
248+
249+
try:
250+
async with ClientSession(
251+
read_stream=client_read,
252+
write_stream=client_write,
253+
task_store=client_task_store,
254+
) as client_session:
255+
await client_session.initialize()
256+
257+
# List tasks
258+
result = await client_session.list_tasks()
259+
260+
assert len(result.tasks) == 3
261+
assert all(task.taskId.startswith("task-") for task in result.tasks)
262+
finally:
263+
tg.cancel_scope.cancel()
264+
265+
266+
@pytest.mark.anyio
267+
async def test_client_list_tasks_with_cursor():
268+
"""Test client.list_tasks() method with pagination cursor."""
269+
task_store = InMemoryTaskStore()
270+
client_task_store = InMemoryTaskStore()
271+
server = Server("test", task_store=task_store)
272+
273+
async with create_client_server_memory_streams() as (client_streams, server_streams):
274+
client_read, client_write = client_streams
275+
server_read, server_write = server_streams
276+
277+
async with anyio.create_task_group() as tg:
278+
tg.start_soon(
279+
lambda: server.run(
280+
server_read,
281+
server_write,
282+
server.create_initialization_options(),
283+
)
284+
)
285+
286+
try:
287+
async with ClientSession(
288+
read_stream=client_read,
289+
write_stream=client_write,
290+
task_store=client_task_store,
291+
) as client_session:
292+
await client_session.initialize()
293+
294+
# List tasks with invalid cursor should raise error
295+
try:
296+
await client_session.list_tasks(cursor="invalid-cursor")
297+
assert False, "Should have raised McpError"
298+
except Exception as e:
299+
assert "Invalid cursor" in str(e) or str(types.INVALID_PARAMS) in str(e)
300+
finally:
301+
tg.cancel_scope.cancel()
302+
303+
304+
@pytest.mark.anyio
305+
async def test_client_delete_task_success():
306+
"""Test client.delete_task() method."""
307+
task_store = InMemoryTaskStore()
308+
client_task_store = InMemoryTaskStore()
309+
server = Server("test", task_store=task_store)
310+
311+
# Create a task
312+
task_id = "task-to-delete"
313+
task_meta = types.TaskMetadata(taskId=task_id, keepAlive=60000)
314+
request = types.ClientRequest(types.PingRequest())
315+
await task_store.create_task(task_meta, "req-1", request.root)
316+
317+
# Verify task exists
318+
task = await task_store.get_task(task_id)
319+
assert task is not None
320+
321+
async with create_client_server_memory_streams() as (client_streams, server_streams):
322+
client_read, client_write = client_streams
323+
server_read, server_write = server_streams
324+
325+
async with anyio.create_task_group() as tg:
326+
tg.start_soon(
327+
lambda: server.run(
328+
server_read,
329+
server_write,
330+
server.create_initialization_options(),
331+
)
332+
)
333+
334+
try:
335+
async with ClientSession(
336+
read_stream=client_read,
337+
write_stream=client_write,
338+
task_store=client_task_store,
339+
) as client_session:
340+
await client_session.initialize()
341+
342+
# Delete task
343+
result = await client_session.delete_task(task_id)
344+
345+
assert result is not None
346+
finally:
347+
tg.cancel_scope.cancel()
348+
349+
# Verify task was deleted
350+
task = await task_store.get_task(task_id)
351+
assert task is None
352+
353+
354+
@pytest.mark.anyio
355+
async def test_client_delete_task_not_found():
356+
"""Test client.delete_task() method for non-existent task."""
357+
task_store = InMemoryTaskStore()
358+
client_task_store = InMemoryTaskStore()
359+
server = Server("test", task_store=task_store)
360+
361+
async with create_client_server_memory_streams() as (client_streams, server_streams):
362+
client_read, client_write = client_streams
363+
server_read, server_write = server_streams
364+
365+
async with anyio.create_task_group() as tg:
366+
tg.start_soon(
367+
lambda: server.run(
368+
server_read,
369+
server_write,
370+
server.create_initialization_options(),
371+
)
372+
)
373+
374+
try:
375+
async with ClientSession(
376+
read_stream=client_read,
377+
write_stream=client_write,
378+
task_store=client_task_store,
379+
) as client_session:
380+
await client_session.initialize()
381+
382+
# Try to delete non-existent task
383+
try:
384+
await client_session.delete_task("non-existent")
385+
assert False, "Should have raised McpError"
386+
except Exception as e:
387+
assert "Failed to delete task" in str(e) or str(types.INVALID_PARAMS) in str(e)
388+
finally:
389+
tg.cancel_scope.cancel()

0 commit comments

Comments
 (0)