Skip to content

Commit a802500

Browse files
cbugwadia32ishymko
andauthored
feat: support async card modifiers (#654)
# Description Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [X] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [X] Make your Pull Request title in the <https://www.conventionalcommits.org/> specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [X] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) - [X] Appropriate docs were updated (if necessary) Fixes #647 🦕 Release-As: 0.3.23 --------- Co-authored-by: Ivan Shymko <ishymko@google.com>
1 parent beb2b4b commit a802500

File tree

11 files changed

+229
-31
lines changed

11 files changed

+229
-31
lines changed

src/a2a/server/apps/jsonrpc/fastapi_app.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from collections.abc import Callable
3+
from collections.abc import Awaitable, Callable
44
from typing import TYPE_CHECKING, Any
55

66

@@ -72,9 +72,10 @@ def __init__( # noqa: PLR0913
7272
http_handler: RequestHandler,
7373
extended_agent_card: AgentCard | None = None,
7474
context_builder: CallContextBuilder | None = None,
75-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
75+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard]
76+
| None = None,
7677
extended_card_modifier: Callable[
77-
[AgentCard, ServerCallContext], AgentCard
78+
[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard
7879
]
7980
| None = None,
8081
max_content_length: int | None = 10 * 1024 * 1024, # 10MB

src/a2a/server/apps/jsonrpc/jsonrpc_app.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import traceback
55

66
from abc import ABC, abstractmethod
7-
from collections.abc import AsyncGenerator, Callable
7+
from collections.abc import AsyncGenerator, Awaitable, Callable
88
from typing import TYPE_CHECKING, Any
99

1010
from pydantic import ValidationError
@@ -51,6 +51,7 @@
5151
PREV_AGENT_CARD_WELL_KNOWN_PATH,
5252
)
5353
from a2a.utils.errors import MethodNotImplementedError
54+
from a2a.utils.helpers import maybe_await
5455

5556

5657
logger = logging.getLogger(__name__)
@@ -178,9 +179,10 @@ def __init__( # noqa: PLR0913
178179
http_handler: RequestHandler,
179180
extended_agent_card: AgentCard | None = None,
180181
context_builder: CallContextBuilder | None = None,
181-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
182+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard]
183+
| None = None,
182184
extended_card_modifier: Callable[
183-
[AgentCard, ServerCallContext], AgentCard
185+
[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard
184186
]
185187
| None = None,
186188
max_content_length: int | None = 10 * 1024 * 1024, # 10MB
@@ -576,7 +578,7 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
576578

577579
card_to_serve = self.agent_card
578580
if self.card_modifier:
579-
card_to_serve = self.card_modifier(card_to_serve)
581+
card_to_serve = await maybe_await(self.card_modifier(card_to_serve))
580582

581583
return JSONResponse(
582584
card_to_serve.model_dump(
@@ -605,7 +607,9 @@ async def _handle_get_authenticated_extended_agent_card(
605607
context = self._context_builder.build(request)
606608
# If no base extended card is provided, pass the public card to the modifier
607609
base_card = card_to_serve if card_to_serve else self.agent_card
608-
card_to_serve = self.extended_card_modifier(base_card, context)
610+
card_to_serve = await maybe_await(
611+
self.extended_card_modifier(base_card, context)
612+
)
609613

610614
if card_to_serve:
611615
return JSONResponse(

src/a2a/server/apps/jsonrpc/starlette_app.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from collections.abc import Callable
3+
from collections.abc import Awaitable, Callable
44
from typing import TYPE_CHECKING, Any
55

66

@@ -54,9 +54,10 @@ def __init__( # noqa: PLR0913
5454
http_handler: RequestHandler,
5555
extended_agent_card: AgentCard | None = None,
5656
context_builder: CallContextBuilder | None = None,
57-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
57+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard]
58+
| None = None,
5859
extended_card_modifier: Callable[
59-
[AgentCard, ServerCallContext], AgentCard
60+
[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard
6061
]
6162
| None = None,
6263
max_content_length: int | None = 10 * 1024 * 1024, # 10MB

src/a2a/server/apps/rest/fastapi_app.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from collections.abc import Callable
3+
from collections.abc import Awaitable, Callable
44
from typing import TYPE_CHECKING, Any
55

66

@@ -49,9 +49,10 @@ def __init__( # noqa: PLR0913
4949
http_handler: RequestHandler,
5050
extended_agent_card: AgentCard | None = None,
5151
context_builder: CallContextBuilder | None = None,
52-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
52+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard]
53+
| None = None,
5354
extended_card_modifier: Callable[
54-
[AgentCard, ServerCallContext], AgentCard
55+
[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard
5556
]
5657
| None = None,
5758
):

src/a2a/server/apps/rest/rest_adapter.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
55
from typing import TYPE_CHECKING, Any
66

7+
from a2a.utils.helpers import maybe_await
8+
79

810
if TYPE_CHECKING:
911
from sse_starlette.sse import EventSourceResponse
@@ -58,9 +60,10 @@ def __init__( # noqa: PLR0913
5860
http_handler: RequestHandler,
5961
extended_agent_card: AgentCard | None = None,
6062
context_builder: CallContextBuilder | None = None,
61-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
63+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard]
64+
| None = None,
6265
extended_card_modifier: Callable[
63-
[AgentCard, ServerCallContext], AgentCard
66+
[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard
6467
]
6568
| None = None,
6669
):
@@ -150,7 +153,7 @@ async def handle_get_agent_card(
150153
"""
151154
card_to_serve = self.agent_card
152155
if self.card_modifier:
153-
card_to_serve = self.card_modifier(card_to_serve)
156+
card_to_serve = await maybe_await(self.card_modifier(card_to_serve))
154157

155158
return card_to_serve.model_dump(mode='json', exclude_none=True)
156159

@@ -182,9 +185,11 @@ async def handle_authenticated_agent_card(
182185

183186
if self.extended_card_modifier:
184187
context = self._context_builder.build(request)
185-
card_to_serve = self.extended_card_modifier(card_to_serve, context)
188+
card_to_serve = await maybe_await(
189+
self.extended_card_modifier(card_to_serve, context)
190+
)
186191
elif self.card_modifier:
187-
card_to_serve = self.card_modifier(card_to_serve)
192+
card_to_serve = await maybe_await(self.card_modifier(card_to_serve))
188193

189194
return card_to_serve.model_dump(mode='json', exclude_none=True)
190195

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44

55
from abc import ABC, abstractmethod
6-
from collections.abc import AsyncIterable, Sequence
6+
from collections.abc import AsyncIterable, Awaitable, Sequence
77

88

99
try:
@@ -34,7 +34,7 @@
3434
from a2a.types import AgentCard, TaskNotFoundError
3535
from a2a.utils import proto_utils
3636
from a2a.utils.errors import ServerError
37-
from a2a.utils.helpers import validate, validate_async_generator
37+
from a2a.utils.helpers import maybe_await, validate, validate_async_generator
3838

3939

4040
logger = logging.getLogger(__name__)
@@ -89,7 +89,8 @@ def __init__(
8989
agent_card: AgentCard,
9090
request_handler: RequestHandler,
9191
context_builder: CallContextBuilder | None = None,
92-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
92+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard]
93+
| None = None,
9394
):
9495
"""Initializes the GrpcHandler.
9596
@@ -339,7 +340,7 @@ async def GetAgentCard(
339340
"""Get the agent card for the agent served."""
340341
card_to_serve = self.agent_card
341342
if self.card_modifier:
342-
card_to_serve = self.card_modifier(card_to_serve)
343+
card_to_serve = await maybe_await(self.card_modifier(card_to_serve))
343344
return proto_utils.ToProto.agent_card(card_to_serve)
344345

345346
async def abort_context(

src/a2a/server/request_handlers/jsonrpc_handler.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from collections.abc import AsyncIterable, Callable
3+
from collections.abc import AsyncIterable, Awaitable, Callable
44

55
from a2a.server.context import ServerCallContext
66
from a2a.server.request_handlers.request_handler import RequestHandler
@@ -46,7 +46,7 @@
4646
TaskStatusUpdateEvent,
4747
)
4848
from a2a.utils.errors import ServerError
49-
from a2a.utils.helpers import validate
49+
from a2a.utils.helpers import maybe_await, validate
5050
from a2a.utils.telemetry import SpanKind, trace_class
5151

5252

@@ -63,10 +63,11 @@ def __init__(
6363
request_handler: RequestHandler,
6464
extended_agent_card: AgentCard | None = None,
6565
extended_card_modifier: Callable[
66-
[AgentCard, ServerCallContext], AgentCard
66+
[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard
6767
]
6868
| None = None,
69-
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
69+
card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard]
70+
| None = None,
7071
):
7172
"""Initializes the JSONRPCHandler.
7273
@@ -450,9 +451,11 @@ async def get_authenticated_extended_card(
450451

451452
card_to_serve = base_card
452453
if self.extended_card_modifier and context:
453-
card_to_serve = self.extended_card_modifier(base_card, context)
454+
card_to_serve = await maybe_await(
455+
self.extended_card_modifier(base_card, context)
456+
)
454457
elif self.card_modifier:
455-
card_to_serve = self.card_modifier(base_card)
458+
card_to_serve = await maybe_await(self.card_modifier(base_card))
456459

457460
return GetAuthenticatedExtendedCardResponse(
458461
root=GetAuthenticatedExtendedCardSuccessResponse(

src/a2a/utils/helpers.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import json
66
import logging
77

8-
from collections.abc import Callable
9-
from typing import Any
8+
from collections.abc import Awaitable, Callable
9+
from typing import Any, TypeVar
1010
from uuid import uuid4
1111

1212
from a2a.types import (
@@ -24,6 +24,9 @@
2424
from a2a.utils.telemetry import trace_function
2525

2626

27+
T = TypeVar('T')
28+
29+
2730
logger = logging.getLogger(__name__)
2831

2932

@@ -368,3 +371,10 @@ def canonicalize_agent_card(agent_card: AgentCard) -> str:
368371
# Recursively remove empty values
369372
cleaned_dict = _clean_empty(card_dict)
370373
return json.dumps(cleaned_dict, separators=(',', ':'), sort_keys=True)
374+
375+
376+
async def maybe_await(value: T | Awaitable[T]) -> T:
377+
"""Awaits a value if it's awaitable, otherwise simply provides it back."""
378+
if inspect.isawaitable(value):
379+
return await value
380+
return value

tests/server/request_handlers/test_grpc_handler.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,34 @@ async def test_get_agent_card_with_modifier(
209209
) -> None:
210210
"""Test GetAgentCard call with a card_modifier."""
211211

212+
async def modifier(card: types.AgentCard) -> types.AgentCard:
213+
modified_card = card.model_copy(deep=True)
214+
modified_card.name = 'Modified gRPC Agent'
215+
return modified_card
216+
217+
grpc_handler_modified = GrpcHandler(
218+
agent_card=sample_agent_card,
219+
request_handler=mock_request_handler,
220+
card_modifier=modifier,
221+
)
222+
223+
request_proto = a2a_pb2.GetAgentCardRequest()
224+
response = await grpc_handler_modified.GetAgentCard(
225+
request_proto, mock_grpc_context
226+
)
227+
228+
assert response.name == 'Modified gRPC Agent'
229+
assert response.version == sample_agent_card.version
230+
231+
232+
@pytest.mark.asyncio
233+
async def test_get_agent_card_with_modifier_sync(
234+
mock_request_handler: AsyncMock,
235+
sample_agent_card: types.AgentCard,
236+
mock_grpc_context: AsyncMock,
237+
) -> None:
238+
"""Test GetAgentCard call with a synchronous card_modifier."""
239+
212240
def modifier(card: types.AgentCard) -> types.AgentCard:
213241
modified_card = card.model_copy(deep=True)
214242
modified_card.name = 'Modified gRPC Agent'

tests/server/request_handlers/test_jsonrpc_handler.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,6 +1295,57 @@ async def test_get_authenticated_extended_card_with_modifier(self) -> None:
12951295
skills=[],
12961296
)
12971297

1298+
async def modifier(
1299+
card: AgentCard, context: ServerCallContext
1300+
) -> AgentCard:
1301+
modified_card = card.model_copy(deep=True)
1302+
modified_card.name = 'Modified Card'
1303+
modified_card.description = (
1304+
f'Modified for context: {context.state.get("foo")}'
1305+
)
1306+
return modified_card
1307+
1308+
handler = JSONRPCHandler(
1309+
self.mock_agent_card,
1310+
mock_request_handler,
1311+
extended_agent_card=mock_base_card,
1312+
extended_card_modifier=modifier,
1313+
)
1314+
request = GetAuthenticatedExtendedCardRequest(id='ext-card-req-mod')
1315+
call_context = ServerCallContext(state={'foo': 'bar'})
1316+
1317+
# Act
1318+
response: GetAuthenticatedExtendedCardResponse = (
1319+
await handler.get_authenticated_extended_card(request, call_context)
1320+
)
1321+
1322+
# Assert
1323+
self.assertIsInstance(
1324+
response.root, GetAuthenticatedExtendedCardSuccessResponse
1325+
)
1326+
self.assertEqual(response.root.id, 'ext-card-req-mod')
1327+
modified_card = response.root.result
1328+
self.assertEqual(modified_card.name, 'Modified Card')
1329+
self.assertEqual(modified_card.description, 'Modified for context: bar')
1330+
self.assertEqual(modified_card.version, '1.0')
1331+
1332+
async def test_get_authenticated_extended_card_with_modifier_sync(
1333+
self,
1334+
) -> None:
1335+
"""Test successful retrieval of a synchronously dynamically modified extended agent card."""
1336+
# Arrange
1337+
mock_request_handler = AsyncMock(spec=DefaultRequestHandler)
1338+
mock_base_card = AgentCard(
1339+
name='Base Card',
1340+
description='Base details',
1341+
url='http://agent.example.com/api',
1342+
version='1.0',
1343+
capabilities=AgentCapabilities(),
1344+
default_input_modes=['text/plain'],
1345+
default_output_modes=['application/json'],
1346+
skills=[],
1347+
)
1348+
12981349
def modifier(card: AgentCard, context: ServerCallContext) -> AgentCard:
12991350
modified_card = card.model_copy(deep=True)
13001351
modified_card.name = 'Modified Card'

0 commit comments

Comments
 (0)