Skip to content

Commit 390b763

Browse files
guglielmo-sanholtskinnerishymko
authored
fix: gRPC metadata header casing and invocation_metadata() call (#676)
This PR fixes two issues related to the grpc support: 1. The extension headers key in the metadata are supposed to be lower key ([ref](https://github.com/grpc/grpc/blob/ce463633548e47c05343bc1243f9ea95b700a908/src/core/lib/surface/validate_metadata.cc#L99)) 2. Invokes correctly the `servicer_context.invocation_metadata` as a [method](https://grpc.github.io/grpc/python/grpc_asyncio.html#grpc.aio.ServicerContext.invocation_metadata) instead of as a property For each of these issues was respectively opened a community PR (#635 , #673). To speed up the integration of these changes, this new PR has been created to integrate both changes at the same time. Fixes #656 --------- Co-authored-by: Holt Skinner <13262395+holtskinner@users.noreply.github.com> Co-authored-by: Ivan Shymko <ishymko@google.com>
1 parent 2a73205 commit 390b763

File tree

5 files changed

+44
-35
lines changed

5 files changed

+44
-35
lines changed

src/a2a/client/transports/grpc.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,11 @@ def _get_grpc_metadata(
6464
extensions: list[str] | None = None,
6565
) -> list[tuple[str, str]] | None:
6666
"""Creates gRPC metadata for extensions."""
67-
if extensions is not None:
68-
return [(HTTP_EXTENSION_HEADER, ','.join(extensions))]
69-
if self.extensions is not None:
70-
return [(HTTP_EXTENSION_HEADER, ','.join(self.extensions))]
67+
extensions_to_use = extensions or self.extensions
68+
if extensions_to_use:
69+
return [
70+
(HTTP_EXTENSION_HEADER.lower(), ','.join(extensions_to_use))
71+
]
7172
return None
7273

7374
@classmethod

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44

55
from abc import ABC, abstractmethod
66
from collections.abc import AsyncIterable, Awaitable, Sequence
7+
from typing import TYPE_CHECKING
78

89

910
try:
1011
import grpc
1112
import grpc.aio
1213

14+
if TYPE_CHECKING:
15+
from grpc.aio._typing import MetadataType
1316
from grpc.aio import Metadata
1417
except ImportError as e:
1518
raise ImportError(
@@ -53,12 +56,12 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
5356
def _get_metadata_value(
5457
context: grpc.aio.ServicerContext, key: str
5558
) -> list[str]:
56-
md = context.invocation_metadata
59+
md: MetadataType | None = context.invocation_metadata()
5760
raw_values: list[str | bytes] = []
61+
lower_key = key.lower()
5862
if isinstance(md, Metadata):
59-
raw_values = md.get_all(key)
63+
raw_values = md.get_all(lower_key)
6064
elif isinstance(md, Sequence):
61-
lower_key = key.lower()
6265
raw_values = [e for (k, e) in md if k.lower() == lower_key]
6366
return [e if isinstance(e, str) else e.decode('utf-8') for e in raw_values]
6467

@@ -417,7 +420,7 @@ def _set_extension_metadata(
417420
if server_context.activated_extensions:
418421
context.set_trailing_metadata(
419422
[
420-
(HTTP_EXTENSION_HEADER, e)
423+
(HTTP_EXTENSION_HEADER.lower(), e)
421424
for e in sorted(server_context.activated_extensions)
422425
]
423426
)

tests/client/transports/test_grpc_client.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ async def test_send_message_task_response(
202202
_, kwargs = mock_grpc_stub.SendMessage.call_args
203203
assert kwargs['metadata'] == [
204204
(
205-
HTTP_EXTENSION_HEADER,
205+
HTTP_EXTENSION_HEADER.lower(),
206206
'https://example.com/test-ext/v3',
207207
)
208208
]
@@ -228,7 +228,7 @@ async def test_send_message_message_response(
228228
_, kwargs = mock_grpc_stub.SendMessage.call_args
229229
assert kwargs['metadata'] == [
230230
(
231-
HTTP_EXTENSION_HEADER,
231+
HTTP_EXTENSION_HEADER.lower(),
232232
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
233233
)
234234
]
@@ -283,7 +283,7 @@ async def test_send_message_streaming( # noqa: PLR0913
283283
_, kwargs = mock_grpc_stub.SendStreamingMessage.call_args
284284
assert kwargs['metadata'] == [
285285
(
286-
HTTP_EXTENSION_HEADER,
286+
HTTP_EXTENSION_HEADER.lower(),
287287
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
288288
)
289289
]
@@ -313,7 +313,7 @@ async def test_get_task(
313313
),
314314
metadata=[
315315
(
316-
HTTP_EXTENSION_HEADER,
316+
HTTP_EXTENSION_HEADER.lower(),
317317
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
318318
)
319319
],
@@ -338,7 +338,7 @@ async def test_get_task_with_history(
338338
),
339339
metadata=[
340340
(
341-
HTTP_EXTENSION_HEADER,
341+
HTTP_EXTENSION_HEADER.lower(),
342342
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
343343
)
344344
],
@@ -363,7 +363,9 @@ async def test_cancel_task(
363363

364364
mock_grpc_stub.CancelTask.assert_awaited_once_with(
365365
a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}'),
366-
metadata=[(HTTP_EXTENSION_HEADER, 'https://example.com/test-ext/v3')],
366+
metadata=[
367+
(HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v3')
368+
],
367369
)
368370
assert response.status.state == TaskState.canceled
369371

@@ -395,7 +397,7 @@ async def test_set_task_callback_with_valid_task(
395397
),
396398
metadata=[
397399
(
398-
HTTP_EXTENSION_HEADER,
400+
HTTP_EXTENSION_HEADER.lower(),
399401
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
400402
)
401403
],
@@ -458,7 +460,7 @@ async def test_get_task_callback_with_valid_task(
458460
),
459461
metadata=[
460462
(
461-
HTTP_EXTENSION_HEADER,
463+
HTTP_EXTENSION_HEADER.lower(),
462464
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
463465
)
464466
],
@@ -506,27 +508,27 @@ async def test_get_task_callback_with_invalid_task(
506508
(
507509
['ext1'],
508510
None,
509-
[(HTTP_EXTENSION_HEADER, 'ext1')],
511+
[(HTTP_EXTENSION_HEADER.lower(), 'ext1')],
510512
), # Case 2: Initial, No input
511513
(
512514
None,
513515
['ext2'],
514-
[(HTTP_EXTENSION_HEADER, 'ext2')],
516+
[(HTTP_EXTENSION_HEADER.lower(), 'ext2')],
515517
), # Case 3: No initial, Input
516518
(
517519
['ext1'],
518520
['ext2'],
519-
[(HTTP_EXTENSION_HEADER, 'ext2')],
521+
[(HTTP_EXTENSION_HEADER.lower(), 'ext2')],
520522
), # Case 4: Initial, Input (override)
521523
(
522524
['ext1'],
523525
['ext2', 'ext3'],
524-
[(HTTP_EXTENSION_HEADER, 'ext2,ext3')],
526+
[(HTTP_EXTENSION_HEADER.lower(), 'ext2,ext3')],
525527
), # Case 5: Initial, Multiple inputs (override)
526528
(
527529
['ext1', 'ext2'],
528530
['ext3'],
529-
[(HTTP_EXTENSION_HEADER, 'ext3')],
531+
[(HTTP_EXTENSION_HEADER.lower(), 'ext3')],
530532
), # Case 6: Multiple initial, Single input (override)
531533
],
532534
)

tests/integration/test_client_server_integration.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,21 +381,24 @@ def channel_factory(address: str) -> Channel:
381381
parts=[Part(root=TextPart(text='Hello, gRPC blocking test!'))],
382382
)
383383
params = MessageSendParams(message=message_to_send)
384+
extensions = ['ext-1', 'ext-2']
384385

385-
result = await transport.send_message(request=params)
386+
result = await transport.send_message(request=params, extensions=extensions)
386387

387388
assert result.id == TASK_FROM_BLOCKING.id
388389
assert result.context_id == TASK_FROM_BLOCKING.context_id
389390

390391
handler.on_message_send.assert_awaited_once()
391392
call_args, _ = handler.on_message_send.call_args
392393
received_params: MessageSendParams = call_args[0]
394+
received_context = call_args[1]
393395

394396
assert received_params.message.message_id == message_to_send.message_id
395397
assert (
396398
received_params.message.parts[0].root.text
397399
== message_to_send.parts[0].root.text
398400
)
401+
assert received_context.requested_extensions == set(extensions)
399402

400403
await transport.close()
401404

tests/server/request_handlers/test_grpc_handler.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -349,9 +349,9 @@ async def test_send_message_with_extensions(
349349
mock_request_handler: AsyncMock,
350350
mock_grpc_context: AsyncMock,
351351
) -> None:
352-
mock_grpc_context.invocation_metadata = grpc.aio.Metadata(
353-
(HTTP_EXTENSION_HEADER, 'foo'),
354-
(HTTP_EXTENSION_HEADER, 'bar'),
352+
mock_grpc_context.invocation_metadata.return_value = grpc.aio.Metadata(
353+
(HTTP_EXTENSION_HEADER.lower(), 'foo'),
354+
(HTTP_EXTENSION_HEADER.lower(), 'bar'),
355355
)
356356

357357
def side_effect(request, context: ServerCallContext):
@@ -379,8 +379,8 @@ def side_effect(request, context: ServerCallContext):
379379
mock_grpc_context.set_trailing_metadata.call_args.args[0]
380380
)
381381
assert set(called_metadata) == {
382-
(HTTP_EXTENSION_HEADER, 'foo'),
383-
(HTTP_EXTENSION_HEADER, 'baz'),
382+
(HTTP_EXTENSION_HEADER.lower(), 'foo'),
383+
(HTTP_EXTENSION_HEADER.lower(), 'baz'),
384384
}
385385

386386
async def test_send_message_with_comma_separated_extensions(
@@ -389,9 +389,9 @@ async def test_send_message_with_comma_separated_extensions(
389389
mock_request_handler: AsyncMock,
390390
mock_grpc_context: AsyncMock,
391391
) -> None:
392-
mock_grpc_context.invocation_metadata = grpc.aio.Metadata(
393-
(HTTP_EXTENSION_HEADER, 'foo ,, bar,'),
394-
(HTTP_EXTENSION_HEADER, 'baz , bar'),
392+
mock_grpc_context.invocation_metadata.return_value = grpc.aio.Metadata(
393+
(HTTP_EXTENSION_HEADER.lower(), 'foo ,, bar,'),
394+
(HTTP_EXTENSION_HEADER.lower(), 'baz , bar'),
395395
)
396396
mock_request_handler.on_message_send.return_value = types.Message(
397397
message_id='1',
@@ -414,9 +414,9 @@ async def test_send_streaming_message_with_extensions(
414414
mock_request_handler: AsyncMock,
415415
mock_grpc_context: AsyncMock,
416416
) -> None:
417-
mock_grpc_context.invocation_metadata = grpc.aio.Metadata(
418-
(HTTP_EXTENSION_HEADER, 'foo'),
419-
(HTTP_EXTENSION_HEADER, 'bar'),
417+
mock_grpc_context.invocation_metadata.return_value = grpc.aio.Metadata(
418+
(HTTP_EXTENSION_HEADER.lower(), 'foo'),
419+
(HTTP_EXTENSION_HEADER.lower(), 'bar'),
420420
)
421421

422422
async def side_effect(request, context: ServerCallContext):
@@ -450,6 +450,6 @@ async def side_effect(request, context: ServerCallContext):
450450
mock_grpc_context.set_trailing_metadata.call_args.args[0]
451451
)
452452
assert set(called_metadata) == {
453-
(HTTP_EXTENSION_HEADER, 'foo'),
454-
(HTTP_EXTENSION_HEADER, 'baz'),
453+
(HTTP_EXTENSION_HEADER.lower(), 'foo'),
454+
(HTTP_EXTENSION_HEADER.lower(), 'baz'),
455455
}

0 commit comments

Comments
 (0)