@@ -991,3 +991,92 @@ async def mock_receive_generator():
991991 assert tool_call_response is not None
992992 # The grounding_metadata should be carried over to tool_call
993993 assert tool_call_response .grounding_metadata == mock_grounding_metadata
994+
995+
996+ @pytest .mark .asyncio
997+ async def test_receive_interrupted_with_pending_text_preserves_flag (
998+ gemini_connection , mock_gemini_session
999+ ):
1000+ """Test that interrupted flag is preserved when flushing pending text."""
1001+ mock_grounding_metadata = types .GroundingMetadata (
1002+ retrieval_queries = ['test query' ],
1003+ )
1004+
1005+ # First message with text content and grounding
1006+ mock_content1 = types .Content (
1007+ role = 'model' , parts = [types .Part .from_text (text = 'partial' )]
1008+ )
1009+ mock_server_content1 = mock .Mock ()
1010+ mock_server_content1 .model_turn = mock_content1
1011+ mock_server_content1 .interrupted = False
1012+ mock_server_content1 .input_transcription = None
1013+ mock_server_content1 .output_transcription = None
1014+ mock_server_content1 .turn_complete = False
1015+ mock_server_content1 .generation_complete = False
1016+ mock_server_content1 .grounding_metadata = mock_grounding_metadata
1017+
1018+ message1 = mock .Mock ()
1019+ message1 .usage_metadata = None
1020+ message1 .server_content = mock_server_content1
1021+ message1 .tool_call = None
1022+ message1 .session_resumption_update = None
1023+
1024+ # Second message with more text
1025+ mock_content2 = types .Content (
1026+ role = 'model' , parts = [types .Part .from_text (text = ' text' )]
1027+ )
1028+ mock_server_content2 = mock .Mock ()
1029+ mock_server_content2 .model_turn = mock_content2
1030+ mock_server_content2 .interrupted = False
1031+ mock_server_content2 .input_transcription = None
1032+ mock_server_content2 .output_transcription = None
1033+ mock_server_content2 .turn_complete = False
1034+ mock_server_content2 .generation_complete = False
1035+ mock_server_content2 .grounding_metadata = None
1036+
1037+ message2 = mock .Mock ()
1038+ message2 .usage_metadata = None
1039+ message2 .server_content = mock_server_content2
1040+ message2 .tool_call = None
1041+ message2 .session_resumption_update = None
1042+
1043+ # Third message with interrupted signal
1044+ mock_server_content3 = mock .Mock ()
1045+ mock_server_content3 .model_turn = None
1046+ mock_server_content3 .interrupted = True
1047+ mock_server_content3 .input_transcription = None
1048+ mock_server_content3 .output_transcription = None
1049+ mock_server_content3 .turn_complete = False
1050+ mock_server_content3 .generation_complete = False
1051+ mock_server_content3 .grounding_metadata = None
1052+
1053+ message3 = mock .Mock ()
1054+ message3 .usage_metadata = None
1055+ message3 .server_content = mock_server_content3
1056+ message3 .tool_call = None
1057+ message3 .session_resumption_update = None
1058+
1059+ async def mock_receive_generator ():
1060+ yield message1
1061+ yield message2
1062+ yield message3
1063+
1064+ receive_mock = mock .Mock (return_value = mock_receive_generator ())
1065+ mock_gemini_session .receive = receive_mock
1066+
1067+ responses = [resp async for resp in gemini_connection .receive ()]
1068+
1069+ # Find the full text response that should have been flushed with interrupted=True
1070+ full_text_responses = [
1071+ r for r in responses if r .content and not r .partial and r .interrupted
1072+ ]
1073+ assert (
1074+ len (full_text_responses ) > 0
1075+ ), 'Should have interrupted full text response'
1076+
1077+ # The full text response should have the accumulated text
1078+ assert full_text_responses [0 ].content .parts [0 ].text == 'partial text'
1079+ # And should carry the grounding_metadata
1080+ assert full_text_responses [0 ].grounding_metadata == mock_grounding_metadata
1081+ # And should have interrupted=True
1082+ assert full_text_responses [0 ].interrupted is True
0 commit comments