Skip to content

Commit 417591f

Browse files
committed
address feedback
1 parent 2b3ec94 commit 417591f

File tree

5 files changed

+21
-40
lines changed

5 files changed

+21
-40
lines changed

src/a2a/server/apps/jsonrpc/jsonrpc_app.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
PREV_AGENT_CARD_WELL_KNOWN_PATH,
5252
)
5353
from a2a.utils.errors import MethodNotImplementedError
54-
from a2a.utils.helpers import apply_optional_awaitable
54+
from a2a.utils.helpers import maybe_await
5555

5656

5757
logger = logging.getLogger(__name__)
@@ -578,9 +578,7 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
578578

579579
card_to_serve = self.agent_card
580580
if self.card_modifier:
581-
card_to_serve = await apply_optional_awaitable(
582-
self.card_modifier, card_to_serve
583-
)
581+
card_to_serve = await maybe_await(self.card_modifier(card_to_serve))
584582

585583
return JSONResponse(
586584
card_to_serve.model_dump(
@@ -609,8 +607,8 @@ async def _handle_get_authenticated_extended_agent_card(
609607
context = self._context_builder.build(request)
610608
# If no base extended card is provided, pass the public card to the modifier
611609
base_card = card_to_serve if card_to_serve else self.agent_card
612-
card_to_serve = await apply_optional_awaitable(
613-
self.extended_card_modifier, base_card, context
610+
card_to_serve = await maybe_await(
611+
self.extended_card_modifier(base_card, context)
614612
)
615613

616614
if card_to_serve:

src/a2a/server/apps/rest/rest_adapter.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
55
from typing import TYPE_CHECKING, Any
66

7-
from a2a.utils.helpers import apply_optional_awaitable
7+
from a2a.utils.helpers import maybe_await
88

99

1010
if TYPE_CHECKING:
@@ -153,9 +153,7 @@ async def handle_get_agent_card(
153153
"""
154154
card_to_serve = self.agent_card
155155
if self.card_modifier:
156-
card_to_serve = await apply_optional_awaitable(
157-
self.card_modifier, card_to_serve
158-
)
156+
card_to_serve = await maybe_await(self.card_modifier(card_to_serve))
159157

160158
return card_to_serve.model_dump(mode='json', exclude_none=True)
161159

@@ -187,13 +185,11 @@ async def handle_authenticated_agent_card(
187185

188186
if self.extended_card_modifier:
189187
context = self._context_builder.build(request)
190-
card_to_serve = await apply_optional_awaitable(
191-
self.extended_card_modifier, card_to_serve, context
188+
card_to_serve = await maybe_await(
189+
self.extended_card_modifier(card_to_serve, context)
192190
)
193191
elif self.card_modifier:
194-
card_to_serve = await apply_optional_awaitable(
195-
self.card_modifier, card_to_serve
196-
)
192+
card_to_serve = await maybe_await(self.card_modifier(card_to_serve))
197193

198194
return card_to_serve.model_dump(mode='json', exclude_none=True)
199195

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,7 @@
3434
from a2a.types import AgentCard, TaskNotFoundError
3535
from a2a.utils import proto_utils
3636
from a2a.utils.errors import ServerError
37-
from a2a.utils.helpers import (
38-
apply_optional_awaitable,
39-
validate,
40-
validate_async_generator,
41-
)
37+
from a2a.utils.helpers import maybe_await, validate, validate_async_generator
4238

4339

4440
logger = logging.getLogger(__name__)
@@ -344,9 +340,7 @@ async def GetAgentCard(
344340
"""Get the agent card for the agent served."""
345341
card_to_serve = self.agent_card
346342
if self.card_modifier:
347-
card_to_serve = await apply_optional_awaitable(
348-
self.card_modifier, card_to_serve
349-
)
343+
card_to_serve = await maybe_await(self.card_modifier(card_to_serve))
350344
return proto_utils.ToProto.agent_card(card_to_serve)
351345

352346
async def abort_context(

src/a2a/server/request_handlers/jsonrpc_handler.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
TaskStatusUpdateEvent,
4747
)
4848
from a2a.utils.errors import ServerError
49-
from a2a.utils.helpers import apply_optional_awaitable, validate
49+
from a2a.utils.helpers import maybe_await, validate
5050
from a2a.utils.telemetry import SpanKind, trace_class
5151

5252

@@ -451,13 +451,11 @@ async def get_authenticated_extended_card(
451451

452452
card_to_serve = base_card
453453
if self.extended_card_modifier and context:
454-
card_to_serve = await apply_optional_awaitable(
455-
self.extended_card_modifier, base_card, context
454+
card_to_serve = await maybe_await(
455+
self.extended_card_modifier(base_card, context)
456456
)
457457
elif self.card_modifier:
458-
card_to_serve = await apply_optional_awaitable(
459-
self.card_modifier, base_card
460-
)
458+
card_to_serve = await maybe_await(self.card_modifier(base_card))
461459

462460
return GetAuthenticatedExtendedCardResponse(
463461
root=GetAuthenticatedExtendedCardSuccessResponse(

src/a2a/utils/helpers.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
import logging
77

88
from collections.abc import Awaitable, Callable
9-
from inspect import isawaitable
10-
from typing import Any, ParamSpec, TypeVar
9+
from typing import Any, TypeVar
1110
from uuid import uuid4
1211

1312
from a2a.types import (
@@ -26,7 +25,6 @@
2625

2726

2827
T = TypeVar('T')
29-
P = ParamSpec('P')
3028

3129

3230
logger = logging.getLogger(__name__)
@@ -375,11 +373,8 @@ def canonicalize_agent_card(agent_card: AgentCard) -> str:
375373
return json.dumps(cleaned_dict, separators=(',', ':'), sort_keys=True)
376374

377375

378-
async def apply_optional_awaitable(
379-
func: Callable[P, Awaitable[T] | T], *args: P.args, **kwargs: P.kwargs
380-
) -> T:
381-
"""Applies a function that may be sync or async and returns the result."""
382-
result = func(*args, **kwargs)
383-
if isawaitable(result):
384-
return await result
385-
return result
376+
async def maybe_await(value: T | Awaitable[T]) -> T:
377+
"""Awaits a value if it's awaitable, otherwise simply provides it back."""
378+
if inspect.isawaitable(value):
379+
return await value
380+
return value

0 commit comments

Comments
 (0)