Skip to content

Commit 2909d66

Browse files
committed
test: add e2e client-server test
Tests basic functionality with real client and server with real handlers, only agent executor is provided in test as it'd be in a real usage.
1 parent b306e44 commit 2909d66

1 file changed

Lines changed: 313 additions & 0 deletions

File tree

Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
from collections.abc import AsyncGenerator
2+
from typing import NamedTuple
3+
4+
import grpc
5+
import httpx
6+
import pytest
7+
import pytest_asyncio
8+
9+
from a2a.client.transports import (
10+
ClientTransport,
11+
GrpcTransport,
12+
JsonRpcTransport,
13+
RestTransport,
14+
)
15+
from a2a.server.agent_execution import AgentExecutor, RequestContext
16+
from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication
17+
from a2a.server.events import EventQueue
18+
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
19+
from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler
20+
from a2a.server.tasks import TaskUpdater
21+
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
22+
from a2a.types import (
23+
AgentCapabilities,
24+
AgentCard,
25+
AgentInterface,
26+
GetTaskRequest,
27+
ListTasksRequest,
28+
Message,
29+
Part,
30+
Role,
31+
SendMessageConfiguration,
32+
SendMessageRequest,
33+
TaskState,
34+
a2a_pb2_grpc,
35+
)
36+
from a2a.utils import TRANSPORT_GRPC, TRANSPORT_HTTP_JSON, TRANSPORT_JSONRPC
37+
38+
39+
class MockAgentExecutor(AgentExecutor):
40+
async def execute(self, context: RequestContext, event_queue: EventQueue):
41+
task_updater = TaskUpdater(
42+
event_queue,
43+
context.task_id,
44+
context.context_id,
45+
)
46+
await task_updater.update_status(TaskState.TASK_STATE_SUBMITTED)
47+
await task_updater.update_status(TaskState.TASK_STATE_WORKING)
48+
await task_updater.update_status(
49+
TaskState.TASK_STATE_COMPLETED,
50+
message=task_updater.new_agent_message([Part(text='done')]),
51+
)
52+
53+
async def cancel(self, context: RequestContext, event_queue: EventQueue):
54+
raise NotImplementedError('Cancellation is not supported')
55+
56+
57+
@pytest.fixture
58+
def agent_card() -> AgentCard:
59+
return AgentCard(
60+
name='Integration Agent',
61+
description='Real in-memory integration testing.',
62+
version='1.0.0',
63+
capabilities=AgentCapabilities(
64+
streaming=True, push_notifications=False
65+
),
66+
skills=[],
67+
default_input_modes=['text/plain'],
68+
default_output_modes=['text/plain'],
69+
supported_interfaces=[
70+
AgentInterface(
71+
protocol_binding=TRANSPORT_HTTP_JSON,
72+
url='http://testserver',
73+
),
74+
AgentInterface(
75+
protocol_binding=TRANSPORT_JSONRPC,
76+
url='http://testserver',
77+
),
78+
AgentInterface(
79+
protocol_binding=TRANSPORT_GRPC,
80+
url='localhost:50051',
81+
),
82+
],
83+
)
84+
85+
86+
class TransportSetup(NamedTuple):
87+
"""Holds the transport and task_store for a given test."""
88+
89+
transport: ClientTransport
90+
task_store: InMemoryTaskStore
91+
92+
93+
@pytest.fixture
94+
def base_e2e_setup():
95+
task_store = InMemoryTaskStore()
96+
handler = DefaultRequestHandler(
97+
agent_executor=MockAgentExecutor(),
98+
task_store=task_store,
99+
queue_manager=InMemoryQueueManager(),
100+
)
101+
return task_store, handler
102+
103+
104+
@pytest.fixture
105+
def rest_setup(agent_card, base_e2e_setup) -> TransportSetup:
106+
task_store, handler = base_e2e_setup
107+
app_builder = A2ARESTFastAPIApplication(agent_card, handler)
108+
app = app_builder.build()
109+
httpx_client = httpx.AsyncClient(
110+
transport=httpx.ASGITransport(app=app), base_url='http://testserver'
111+
)
112+
transport = RestTransport(httpx_client=httpx_client, agent_card=agent_card)
113+
return TransportSetup(
114+
transport=transport,
115+
task_store=task_store,
116+
)
117+
118+
119+
@pytest.fixture
120+
def jsonrpc_setup(agent_card, base_e2e_setup) -> TransportSetup:
121+
task_store, handler = base_e2e_setup
122+
app_builder = A2AFastAPIApplication(
123+
agent_card, handler, extended_agent_card=agent_card
124+
)
125+
app = app_builder.build()
126+
httpx_client = httpx.AsyncClient(
127+
transport=httpx.ASGITransport(app=app), base_url='http://testserver'
128+
)
129+
transport = JsonRpcTransport(
130+
httpx_client=httpx_client, agent_card=agent_card
131+
)
132+
return TransportSetup(
133+
transport=transport,
134+
task_store=task_store,
135+
)
136+
137+
138+
@pytest_asyncio.fixture
139+
async def grpc_setup(
140+
agent_card: AgentCard, base_e2e_setup
141+
) -> AsyncGenerator[TransportSetup, None]:
142+
task_store, handler = base_e2e_setup
143+
server = grpc.aio.server()
144+
port = server.add_insecure_port('[::]:0')
145+
server_address = f'localhost:{port}'
146+
147+
grpc_agent_card = AgentCard()
148+
grpc_agent_card.CopyFrom(agent_card)
149+
150+
# Update the gRPC interface dynamically based on the assigned port
151+
for interface in grpc_agent_card.supported_interfaces:
152+
if interface.protocol_binding == TRANSPORT_GRPC:
153+
interface.url = server_address
154+
break
155+
else:
156+
raise ValueError('No gRPC interface found in agent card')
157+
158+
servicer = GrpcHandler(grpc_agent_card, handler)
159+
a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server)
160+
await server.start()
161+
162+
channel = grpc.aio.insecure_channel(server_address)
163+
transport = GrpcTransport(agent_card=grpc_agent_card, channel=channel)
164+
yield TransportSetup(
165+
transport=transport,
166+
task_store=task_store,
167+
)
168+
169+
await channel.close()
170+
await server.stop(0)
171+
172+
173+
@pytest.fixture(
174+
params=[
175+
pytest.param('rest_setup', id='REST'),
176+
pytest.param('jsonrpc_setup', id='JSON-RPC'),
177+
pytest.param('grpc_setup', id='gRPC'),
178+
]
179+
)
180+
def transport_setups(request) -> TransportSetup:
181+
"""Parametrized fixture that runs tests against all supported transports."""
182+
return request.getfixturevalue(request.param)
183+
184+
185+
@pytest.mark.asyncio
186+
async def test_end_to_end_send_message_blocking(transport_setups):
187+
transport = transport_setups.transport
188+
189+
message_to_send = Message(
190+
role=Role.ROLE_USER,
191+
message_id='msg-e2e-blocking',
192+
parts=[Part(text='Run dummy agent!')],
193+
)
194+
configuration = SendMessageConfiguration(blocking=True)
195+
params = SendMessageRequest(
196+
message=message_to_send, configuration=configuration
197+
)
198+
199+
response = await transport.send_message(request=params)
200+
201+
task = response.task
202+
assert task.id
203+
assert task.status.state == TaskState.TASK_STATE_COMPLETED
204+
205+
206+
@pytest.mark.asyncio
207+
async def test_end_to_end_send_message_non_blocking(transport_setups):
208+
transport = transport_setups.transport
209+
210+
message_to_send = Message(
211+
role=Role.ROLE_USER,
212+
message_id='msg-e2e-non-blocking',
213+
parts=[Part(text='Run dummy agent!')],
214+
)
215+
configuration = SendMessageConfiguration(blocking=False)
216+
params = SendMessageRequest(
217+
message=message_to_send, configuration=configuration
218+
)
219+
220+
response = await transport.send_message(request=params)
221+
222+
task = response.task
223+
assert task.id
224+
225+
226+
@pytest.mark.asyncio
227+
async def test_end_to_end_send_message_streaming(transport_setups):
228+
transport = transport_setups.transport
229+
230+
message_to_send = Message(
231+
role=Role.ROLE_USER,
232+
message_id='msg-e2e-streaming',
233+
parts=[Part(text='Run dummy agent!')],
234+
)
235+
params = SendMessageRequest(message=message_to_send)
236+
237+
events = [
238+
event
239+
async for event in transport.send_message_streaming(request=params)
240+
]
241+
242+
assert len(events) > 0
243+
final_event = events[-1]
244+
245+
assert final_event.HasField('status_update')
246+
assert final_event.status_update.task_id
247+
assert (
248+
final_event.status_update.status.state == TaskState.TASK_STATE_COMPLETED
249+
)
250+
251+
252+
@pytest.mark.asyncio
253+
async def test_end_to_end_get_task(transport_setups):
254+
transport = transport_setups.transport
255+
256+
message_to_send = Message(
257+
role=Role.ROLE_USER,
258+
message_id='msg-e2e-get',
259+
parts=[Part(text='Test Get Task')],
260+
)
261+
response = await transport.send_message(
262+
request=SendMessageRequest(message=message_to_send)
263+
)
264+
task_id = response.task.id
265+
266+
get_request = GetTaskRequest(id=task_id)
267+
retrieved_task = await transport.get_task(request=get_request)
268+
269+
assert retrieved_task.id == task_id
270+
assert retrieved_task.status.state in {
271+
TaskState.TASK_STATE_SUBMITTED,
272+
TaskState.TASK_STATE_WORKING,
273+
TaskState.TASK_STATE_COMPLETED,
274+
}
275+
276+
277+
@pytest.mark.asyncio
278+
async def test_end_to_end_list_tasks(transport_setups):
279+
transport = transport_setups.transport
280+
281+
total_items = 6
282+
page_size = 2
283+
284+
for i in range(total_items):
285+
await transport.send_message(
286+
request=SendMessageRequest(
287+
message=Message(
288+
role=Role.ROLE_USER,
289+
message_id=f'msg-e2e-list-{i}',
290+
parts=[Part(text=f'Test List Tasks {i}')],
291+
),
292+
configuration=SendMessageConfiguration(blocking=False),
293+
)
294+
)
295+
296+
list_request = ListTasksRequest(page_size=page_size)
297+
298+
unique_task_ids = set()
299+
token = None
300+
301+
while token != '':
302+
if token:
303+
list_request.page_token = token
304+
305+
list_response = await transport.list_tasks(request=list_request)
306+
assert 0 < len(list_response.tasks) <= page_size
307+
308+
for task in list_response.tasks:
309+
unique_task_ids.add(task.id)
310+
311+
token = list_response.next_page_token
312+
313+
assert len(unique_task_ids) == total_items

0 commit comments

Comments
 (0)