Skip to content

Commit c7768cd

Browse files
committed
make card modifier and extended card modifier async
1 parent ffe31e2 commit c7768cd

11 files changed

Lines changed: 226 additions & 38 deletions

File tree

src/a2a/server/apps/jsonrpc/fastapi_app.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from collections.abc import Callable
3+
from collections.abc import Awaitable, Callable
44
from typing import TYPE_CHECKING, Any
55

66

@@ -72,9 +72,10 @@ def __init__( # noqa: PLR0913
7272
http_handler: RequestHandler,
7373
extended_agent_card: AgentCard | None = None,
7474
context_builder: CallContextBuilder | None = None,
75-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
75+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard]
76+
| None = None,
7677
extended_card_modifier: Callable[
77-
[AgentCard, ServerCallContext], AgentCard
78+
[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard
7879
]
7980
| None = None,
8081
max_content_length: int | None = 10 * 1024 * 1024, # 10MB

src/a2a/server/apps/jsonrpc/jsonrpc_app.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import traceback
55

66
from abc import ABC, abstractmethod
7-
from collections.abc import AsyncGenerator, Callable
7+
from collections.abc import AsyncGenerator, Awaitable, Callable
8+
from inspect import isawaitable
89
from typing import TYPE_CHECKING, Any
910

1011
from pydantic import ValidationError
@@ -178,9 +179,10 @@ def __init__( # noqa: PLR0913
178179
http_handler: RequestHandler,
179180
extended_agent_card: AgentCard | None = None,
180181
context_builder: CallContextBuilder | None = None,
181-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
182+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard]
183+
| None = None,
182184
extended_card_modifier: Callable[
183-
[AgentCard, ServerCallContext], AgentCard
185+
[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard
184186
]
185187
| None = None,
186188
max_content_length: int | None = 10 * 1024 * 1024, # 10MB
@@ -576,7 +578,8 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
576578

577579
card_to_serve = self.agent_card
578580
if self.card_modifier:
579-
card_to_serve = self.card_modifier(card_to_serve)
581+
result = self.card_modifier(card_to_serve)
582+
card_to_serve = await result if isawaitable(result) else result
580583

581584
return JSONResponse(
582585
card_to_serve.model_dump(
@@ -605,7 +608,8 @@ async def _handle_get_authenticated_extended_agent_card(
605608
context = self._context_builder.build(request)
606609
# If no base extended card is provided, pass the public card to the modifier
607610
base_card = card_to_serve if card_to_serve else self.agent_card
608-
card_to_serve = self.extended_card_modifier(base_card, context)
611+
result = await self.extended_card_modifier(base_card, context)
612+
card_to_serve = await result if isawaitable(result) else result
609613

610614
if card_to_serve:
611615
return JSONResponse(

src/a2a/server/apps/jsonrpc/starlette_app.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from collections.abc import Callable
3+
from collections.abc import Awaitable, Callable
44
from typing import TYPE_CHECKING, Any
55

66

@@ -54,9 +54,10 @@ def __init__( # noqa: PLR0913
5454
http_handler: RequestHandler,
5555
extended_agent_card: AgentCard | None = None,
5656
context_builder: CallContextBuilder | None = None,
57-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
57+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard]
58+
| None = None,
5859
extended_card_modifier: Callable[
59-
[AgentCard, ServerCallContext], AgentCard
60+
[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard
6061
]
6162
| None = None,
6263
max_content_length: int | None = 10 * 1024 * 1024, # 10MB

src/a2a/server/apps/rest/fastapi_app.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from collections.abc import Callable
3+
from collections.abc import Awaitable, Callable
44
from typing import TYPE_CHECKING, Any
55

66

@@ -49,9 +49,10 @@ def __init__( # noqa: PLR0913
4949
http_handler: RequestHandler,
5050
extended_agent_card: AgentCard | None = None,
5151
context_builder: CallContextBuilder | None = None,
52-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
52+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard]
53+
| None = None,
5354
extended_card_modifier: Callable[
54-
[AgentCard, ServerCallContext], AgentCard
55+
[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard
5556
]
5657
| None = None,
5758
):

src/a2a/server/apps/rest/rest_adapter.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33

44
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
5+
from inspect import isawaitable
56
from typing import TYPE_CHECKING, Any
67

78

@@ -58,9 +59,10 @@ def __init__( # noqa: PLR0913
5859
http_handler: RequestHandler,
5960
extended_agent_card: AgentCard | None = None,
6061
context_builder: CallContextBuilder | None = None,
61-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
62+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard]
63+
| None = None,
6264
extended_card_modifier: Callable[
63-
[AgentCard, ServerCallContext], AgentCard
65+
[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard
6466
]
6567
| None = None,
6668
):
@@ -150,7 +152,8 @@ async def handle_get_agent_card(
150152
"""
151153
card_to_serve = self.agent_card
152154
if self.card_modifier:
153-
card_to_serve = self.card_modifier(card_to_serve)
155+
result = self.card_modifier(card_to_serve)
156+
card_to_serve = await result if isawaitable(result) else result
154157

155158
return card_to_serve.model_dump(mode='json', exclude_none=True)
156159

@@ -182,9 +185,11 @@ async def handle_authenticated_agent_card(
182185

183186
if self.extended_card_modifier:
184187
context = self._context_builder.build(request)
185-
card_to_serve = self.extended_card_modifier(card_to_serve, context)
188+
result = self.extended_card_modifier(card_to_serve, context)
189+
card_to_serve = await result if isawaitable(result) else result
186190
elif self.card_modifier:
187-
card_to_serve = self.card_modifier(card_to_serve)
191+
result = self.card_modifier(card_to_serve)
192+
card_to_serve = await result if isawaitable(result) else result
188193

189194
return card_to_serve.model_dump(mode='json', exclude_none=True)
190195

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import logging
44

55
from abc import ABC, abstractmethod
6-
from collections.abc import AsyncIterable, Sequence
6+
from collections.abc import AsyncIterable, Awaitable, Sequence
7+
from inspect import isawaitable
78

89

910
try:
@@ -89,7 +90,8 @@ def __init__(
8990
agent_card: AgentCard,
9091
request_handler: RequestHandler,
9192
context_builder: CallContextBuilder | None = None,
92-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
93+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard]
94+
| None = None,
9395
):
9496
"""Initializes the GrpcHandler.
9597
@@ -339,7 +341,8 @@ async def GetAgentCard(
339341
"""Get the agent card for the agent served."""
340342
card_to_serve = self.agent_card
341343
if self.card_modifier:
342-
card_to_serve = self.card_modifier(card_to_serve)
344+
result = self.card_modifier(card_to_serve)
345+
card_to_serve = await result if isawaitable(result) else result
343346
return proto_utils.ToProto.agent_card(card_to_serve)
344347

345348
async def abort_context(

src/a2a/server/request_handlers/jsonrpc_handler.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22

3-
from collections.abc import AsyncIterable, Callable
3+
from collections.abc import AsyncIterable, Awaitable, Callable
4+
from inspect import isawaitable
45

56
from a2a.server.context import ServerCallContext
67
from a2a.server.request_handlers.request_handler import RequestHandler
@@ -63,10 +64,11 @@ def __init__(
6364
request_handler: RequestHandler,
6465
extended_agent_card: AgentCard | None = None,
6566
extended_card_modifier: Callable[
66-
[AgentCard, ServerCallContext], AgentCard
67+
[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard
6768
]
6869
| None = None,
69-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
70+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard]
71+
| None = None,
7072
):
7173
"""Initializes the JSONRPCHandler.
7274
@@ -450,9 +452,11 @@ async def get_authenticated_extended_card(
450452

451453
card_to_serve = base_card
452454
if self.extended_card_modifier and context:
453-
card_to_serve = self.extended_card_modifier(base_card, context)
455+
result = self.extended_card_modifier(base_card, context)
456+
card_to_serve = await result if isawaitable(result) else result
454457
elif self.card_modifier:
455-
card_to_serve = self.card_modifier(base_card)
458+
result = self.card_modifier(base_card)
459+
card_to_serve = await result if isawaitable(result) else result
456460

457461
return GetAuthenticatedExtendedCardResponse(
458462
root=GetAuthenticatedExtendedCardSuccessResponse(

tests/server/request_handlers/test_grpc_handler.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,34 @@ async def test_get_agent_card_with_modifier(
209209
) -> None:
210210
"""Test GetAgentCard call with a card_modifier."""
211211

212+
async def modifier(card: types.AgentCard) -> types.AgentCard:
213+
modified_card = card.model_copy(deep=True)
214+
modified_card.name = 'Modified gRPC Agent'
215+
return modified_card
216+
217+
grpc_handler_modified = GrpcHandler(
218+
agent_card=sample_agent_card,
219+
request_handler=mock_request_handler,
220+
card_modifier=modifier,
221+
)
222+
223+
request_proto = a2a_pb2.GetAgentCardRequest()
224+
response = await grpc_handler_modified.GetAgentCard(
225+
request_proto, mock_grpc_context
226+
)
227+
228+
assert response.name == 'Modified gRPC Agent'
229+
assert response.version == sample_agent_card.version
230+
231+
232+
@pytest.mark.asyncio
233+
async def test_get_agent_card_with_modifier_sync(
234+
mock_request_handler: AsyncMock,
235+
sample_agent_card: types.AgentCard,
236+
mock_grpc_context: AsyncMock,
237+
) -> None:
238+
"""Test GetAgentCard call with a synchronous card_modifier."""
239+
212240
def modifier(card: types.AgentCard) -> types.AgentCard:
213241
modified_card = card.model_copy(deep=True)
214242
modified_card.name = 'Modified gRPC Agent'

tests/server/request_handlers/test_jsonrpc_handler.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,6 +1295,57 @@ async def test_get_authenticated_extended_card_with_modifier(self) -> None:
12951295
skills=[],
12961296
)
12971297

1298+
async def modifier(
1299+
card: AgentCard, context: ServerCallContext
1300+
) -> AgentCard:
1301+
modified_card = card.model_copy(deep=True)
1302+
modified_card.name = 'Modified Card'
1303+
modified_card.description = (
1304+
f'Modified for context: {context.state.get("foo")}'
1305+
)
1306+
return modified_card
1307+
1308+
handler = JSONRPCHandler(
1309+
self.mock_agent_card,
1310+
mock_request_handler,
1311+
extended_agent_card=mock_base_card,
1312+
extended_card_modifier=modifier,
1313+
)
1314+
request = GetAuthenticatedExtendedCardRequest(id='ext-card-req-mod')
1315+
call_context = ServerCallContext(state={'foo': 'bar'})
1316+
1317+
# Act
1318+
response: GetAuthenticatedExtendedCardResponse = (
1319+
await handler.get_authenticated_extended_card(request, call_context)
1320+
)
1321+
1322+
# Assert
1323+
self.assertIsInstance(
1324+
response.root, GetAuthenticatedExtendedCardSuccessResponse
1325+
)
1326+
self.assertEqual(response.root.id, 'ext-card-req-mod')
1327+
modified_card = response.root.result
1328+
self.assertEqual(modified_card.name, 'Modified Card')
1329+
self.assertEqual(modified_card.description, 'Modified for context: bar')
1330+
self.assertEqual(modified_card.version, '1.0')
1331+
1332+
async def test_get_authenticated_extended_card_with_modifier_sync(
1333+
self,
1334+
) -> None:
1335+
"""Test successful retrieval of a synchronously dynamically modified extended agent card."""
1336+
# Arrange
1337+
mock_request_handler = AsyncMock(spec=DefaultRequestHandler)
1338+
mock_base_card = AgentCard(
1339+
name='Base Card',
1340+
description='Base details',
1341+
url='http://agent.example.com/api',
1342+
version='1.0',
1343+
capabilities=AgentCapabilities(),
1344+
default_input_modes=['text/plain'],
1345+
default_output_modes=['application/json'],
1346+
skills=[],
1347+
)
1348+
12981349
def modifier(card: AgentCard, context: ServerCallContext) -> AgentCard:
12991350
modified_card = card.model_copy(deep=True)
13001351
modified_card.name = 'Modified Card'

0 commit comments

Comments
 (0)