Commit e259b35

test: Add test for interrupted signal preservation with pending text
Add test_receive_interrupted_with_pending_text_preserves_flag to verify:

- interrupted flag is preserved when flushing pending text
- grounding_metadata is carried through to the flushed response
- accumulated text is properly merged before interruption

Addresses MEDIUM priority review comment about missing test coverage for edge cases.
1 parent: 2024ead

1 file changed: +89 −0 lines

tests/unittests/models/test_gemini_llm_connection.py

@@ -991,3 +991,92 @@ async def mock_receive_generator():
   assert tool_call_response is not None
   # The grounding_metadata should be carried over to tool_call
   assert tool_call_response.grounding_metadata == mock_grounding_metadata
+
+
+@pytest.mark.asyncio
+async def test_receive_interrupted_with_pending_text_preserves_flag(
+    gemini_connection, mock_gemini_session
+):
+  """Test that interrupted flag is preserved when flushing pending text."""
+  mock_grounding_metadata = types.GroundingMetadata(
+      retrieval_queries=['test query'],
+  )
+
+  # First message with text content and grounding
+  mock_content1 = types.Content(
+      role='model', parts=[types.Part.from_text(text='partial')]
+  )
+  mock_server_content1 = mock.Mock()
+  mock_server_content1.model_turn = mock_content1
+  mock_server_content1.interrupted = False
+  mock_server_content1.input_transcription = None
+  mock_server_content1.output_transcription = None
+  mock_server_content1.turn_complete = False
+  mock_server_content1.generation_complete = False
+  mock_server_content1.grounding_metadata = mock_grounding_metadata
+
+  message1 = mock.Mock()
+  message1.usage_metadata = None
+  message1.server_content = mock_server_content1
+  message1.tool_call = None
+  message1.session_resumption_update = None
+
+  # Second message with more text
+  mock_content2 = types.Content(
+      role='model', parts=[types.Part.from_text(text=' text')]
+  )
+  mock_server_content2 = mock.Mock()
+  mock_server_content2.model_turn = mock_content2
+  mock_server_content2.interrupted = False
+  mock_server_content2.input_transcription = None
+  mock_server_content2.output_transcription = None
+  mock_server_content2.turn_complete = False
+  mock_server_content2.generation_complete = False
+  mock_server_content2.grounding_metadata = None
+
+  message2 = mock.Mock()
+  message2.usage_metadata = None
+  message2.server_content = mock_server_content2
+  message2.tool_call = None
+  message2.session_resumption_update = None
+
+  # Third message with interrupted signal
+  mock_server_content3 = mock.Mock()
+  mock_server_content3.model_turn = None
+  mock_server_content3.interrupted = True
+  mock_server_content3.input_transcription = None
+  mock_server_content3.output_transcription = None
+  mock_server_content3.turn_complete = False
+  mock_server_content3.generation_complete = False
+  mock_server_content3.grounding_metadata = None
+
+  message3 = mock.Mock()
+  message3.usage_metadata = None
+  message3.server_content = mock_server_content3
+  message3.tool_call = None
+  message3.session_resumption_update = None
+
+  async def mock_receive_generator():
+    yield message1
+    yield message2
+    yield message3
+
+  receive_mock = mock.Mock(return_value=mock_receive_generator())
+  mock_gemini_session.receive = receive_mock
+
+  responses = [resp async for resp in gemini_connection.receive()]
+
+  # Find the full text response that should have been flushed with interrupted=True
+  full_text_responses = [
+      r for r in responses if r.content and not r.partial and r.interrupted
+  ]
+  assert (
+      len(full_text_responses) > 0
+  ), 'Should have interrupted full text response'
+
+  # The full text response should have the accumulated text
+  assert full_text_responses[0].content.parts[0].text == 'partial text'
+  # And should carry the grounding_metadata
+  assert full_text_responses[0].grounding_metadata == mock_grounding_metadata
+  # And should have interrupted=True
+  assert full_text_responses[0].interrupted is True
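
For orientation, below is a minimal sketch of the buffering behavior this test exercises, inferred only from the test's assertions. It is not the actual GeminiLlmConnection.receive implementation: SketchResponse and receive_sketch are hypothetical names, and real concerns such as transcriptions, turn/generation completion, and usage metadata are omitted.

from dataclasses import dataclass
from typing import Any, AsyncIterator, Optional


@dataclass
class SketchResponse:
  """Hypothetical stand-in for the real LlmResponse type."""

  text: str = ''
  partial: bool = False
  interrupted: bool = False
  grounding_metadata: Optional[Any] = None


async def receive_sketch(
    messages: AsyncIterator[Any],
) -> AsyncIterator[SketchResponse]:
  """Accumulates partial text and flushes it when an interrupted signal arrives."""
  pending_text = ''
  pending_grounding = None
  async for msg in messages:
    sc = msg.server_content
    if sc.model_turn:
      # Emit each text chunk as a partial response and buffer it.
      chunk = ''.join(part.text or '' for part in sc.model_turn.parts)
      pending_text += chunk
      pending_grounding = sc.grounding_metadata or pending_grounding
      yield SketchResponse(
          text=chunk, partial=True, grounding_metadata=sc.grounding_metadata
      )
    if sc.interrupted:
      if pending_text:
        # Flush the accumulated text as one non-partial response; the
        # interrupted flag and the last grounding_metadata must survive.
        yield SketchResponse(
            text=pending_text,
            partial=False,
            interrupted=True,
            grounding_metadata=pending_grounding,
        )
        pending_text, pending_grounding = '', None
      else:
        yield SketchResponse(interrupted=True)

Flushing on interruption means a consumer that only looks at non-partial responses still sees the text produced before the barge-in ('partial text' here), together with the interrupted flag and the grounding metadata from the earlier chunks, which is exactly what the three assertions above check.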
