diff --git a/src/agentscope/tts/_dashscope_realtime_tts_model.py b/src/agentscope/tts/_dashscope_realtime_tts_model.py index df02bdd2b0..9b5b69d17b 100644 --- a/src/agentscope/tts/_dashscope_realtime_tts_model.py +++ b/src/agentscope/tts/_dashscope_realtime_tts_model.py @@ -1,11 +1,16 @@ # -*- coding: utf-8 -*- +# pylint: disable=too-many-branches, too-many-statements """DashScope Realtime TTS model implementation.""" +import asyncio import threading from typing import Any, Literal, TYPE_CHECKING, AsyncGenerator +from websocket import WebSocketConnectionClosedException + from ._tts_base import TTSModelBase from ._tts_response import TTSResponse +from .._logging import logger from ..message import Msg, AudioBlock, Base64Source from ..types import JSONSerializableObject @@ -82,6 +87,26 @@ def on_event(self, response: dict[str, Any]) -> None: traceback.print_exc() self.finish_event.set() + def on_close(self, close_status_code: int, close_msg: str) -> None: + """Called when the WebSocket connection is closed. + + Args: + close_status_code (`int`): + The close status code. + close_msg (`str`): + The close message. + """ + # Unblock waiting operations to prevent deadlock + self.finish_event.set() + self.chunk_event.set() + + if close_status_code: + logger.warning( + "TTS WebSocket connection closed with code %s: %s", + close_status_code, + close_msg, + ) + async def get_audio_data(self, block: bool) -> TTSResponse: """Get the current accumulated audio data as base64 string so far. @@ -164,6 +189,10 @@ async def _reset(self) -> None: self.chunk_event.clear() self._audio_data = "" + def has_audio_data(self) -> bool: + """Check if audio data has been received.""" + return bool(self._audio_data) + return _DashScopeRealtimeTTSCallback @@ -196,6 +225,8 @@ def __init__( cold_start_words: int | None = None, client_kwargs: dict[str, JSONSerializableObject] | None = None, generate_kwargs: dict[str, JSONSerializableObject] | None = None, + max_retries: int = 3, + retry_delay: float = 5.0, ) -> None: """Initialize the DashScope TTS model by specifying the model, voice, and other parameters. @@ -240,6 +271,10 @@ def __init__( optional): The extra keyword arguments used in DashScope realtime tts API generation. + max_retries (`int`, defaults to 3): + The maximum number of retry attempts when TTS synthesis fails. + retry_delay (`float`, defaults to 5.0): + The delay in seconds before retrying. Uses exponential backoff. """ super().__init__(model_name=model_name, stream=stream) @@ -255,6 +290,8 @@ def __init__( self.cold_start_words = cold_start_words self.client_kwargs = client_kwargs or {} self.generate_kwargs = generate_kwargs or {} + self.max_retries = max_retries + self.retry_delay = retry_delay # Initialize TTS client # Save callback reference (for DashScope SDK) @@ -298,9 +335,29 @@ async def close(self) -> None: self._connected = False - self._tts_client.finish() self._tts_client.close() + async def _reconnect(self) -> None: + """Reconnect to TTS service by recreating the client.""" + from dashscope.audio.qwen_tts_realtime import QwenTtsRealtime + + try: + self._tts_client.close() + except Exception: + pass + + self._dashscope_callback = _get_qwen_tts_realtime_callback_class()() + self._tts_client = QwenTtsRealtime( + model=self.model_name, + callback=self._dashscope_callback, + **self.client_kwargs, + ) + self._connected = False + self._first_send = True + self._current_msg_id = None + self._current_prefix = "" + await self.connect() + async def push( self, msg: Msg, @@ -362,7 +419,12 @@ async def push( delta_to_send = text.removeprefix(self._current_prefix) if delta_to_send: - self._tts_client.append_text(delta_to_send) + try: + self._tts_client.append_text(delta_to_send) + except WebSocketConnectionClosedException: + # Connection closed, return empty response + # synthesize() will handle retry + return TTSResponse(content=None) # Record sent prefix self._current_prefix += delta_to_send @@ -399,7 +461,11 @@ async def synthesize( "TTS model is not connected. Call `connect()` first.", ) - if self._current_msg_id is not None and self._current_msg_id != msg.id: + if ( + self._current_msg_id is not None + and msg + and self._current_msg_id != msg.id + ): raise RuntimeError( "DashScopeRealtimeTTSModel can only handle one streaming " "input request at a time. Please ensure that all chunks " @@ -416,19 +482,85 @@ async def synthesize( self._current_prefix, ) - # Determine if we should send text based on cold start settings only - # for the first input chunk and not the last chunk - if delta_to_send: - self._tts_client.append_text(delta_to_send) + full_text = (msg.get_text_content() or "") if msg else "" - # To keep correct prefix tracking - self._current_prefix += delta_to_send - self._first_send = False + # Synthesize with retry - if we have text but get no audio, retry + delay = self.retry_delay - # We need to block until synthesis is complete to get all audio - self._tts_client.commit() - self._tts_client.finish() + for attempt in range(self.max_retries): + try: + # Send remaining text if any + if delta_to_send: + self._tts_client.append_text(delta_to_send) + self._current_prefix += delta_to_send + self._first_send = False + + # Commit and finish + self._tts_client.commit() + self._tts_client.finish() + + # Wait for synthesis to complete + self._dashscope_callback.finish_event.wait() + + # Check if we got audio (only retry if we had text but no + # audio) + has_audio = self._dashscope_callback.has_audio_data() + if full_text and not has_audio: + if attempt < self.max_retries - 1: + logger.warning( + "TTS: no audio received, retrying (%d/%d) in " + "%.1fs...", + attempt + 1, + self.max_retries, + delay, + ) + await asyncio.sleep(delay) + await self._reconnect() + # After reconnect, need to resend full text + delta_to_send = full_text + delay *= 2 + continue + logger.error( + "TTS: no audio after %d attempts.", + self.max_retries, + ) + # Reset state before raising + self._current_msg_id = None + self._first_send = True + self._current_prefix = "" + raise RuntimeError( + f"TTS synthesis failed: no audio after" + f" {self.max_retries} attempts", + ) + + # Success + break + + except WebSocketConnectionClosedException: + if attempt < self.max_retries - 1: + logger.warning( + "TTS failed, retrying (%d/%d) in %.1fs...", + attempt + 1, + self.max_retries, + delay, + ) + await asyncio.sleep(delay) + await self._reconnect() + # After reconnect, need to resend full text + delta_to_send = full_text + delay *= 2 + else: + logger.error( + "TTS failed after %d attempts.", + self.max_retries, + ) + # Reset state before raising + self._current_msg_id = None + self._first_send = True + self._current_prefix = "" + raise + # Get result if self.stream: # Return an async generator for audio chunks res = self._dashscope_callback.get_audio_chunk() diff --git a/tests/tts_dashscope_test.py b/tests/tts_dashscope_test.py index 8a37f58bd7..dc1f15a6c8 100644 --- a/tests/tts_dashscope_test.py +++ b/tests/tts_dashscope_test.py @@ -132,6 +132,13 @@ async def test_synthesize_non_streaming(self) -> None: api_key=self.api_key, stream=False, ) as model: + # Mock finish_event to not block + model._dashscope_callback.finish_event = Mock() + model._dashscope_callback.finish_event.wait = Mock() + # Mock has_audio_data to return True (skip retry) + model._dashscope_callback.has_audio_data = Mock( + return_value=True, + ) model._dashscope_callback.get_audio_data = AsyncMock( return_value=TTSResponse( content=AudioBlock( @@ -169,6 +176,13 @@ async def test_synthesize_streaming(self) -> None: api_key=self.api_key, stream=True, ) as model: + # Mock finish_event to not block + model._dashscope_callback.finish_event = Mock() + model._dashscope_callback.finish_event.wait = Mock() + # Mock has_audio_data to return True (skip retry) + model._dashscope_callback.has_audio_data = Mock( + return_value=True, + ) async def mock_generator() -> AsyncGenerator[ TTSResponse,