Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 145 additions & 13 deletions src/agentscope/tts/_dashscope_realtime_tts_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches, too-many-statements
"""DashScope Realtime TTS model implementation."""

import asyncio
import threading
from typing import Any, Literal, TYPE_CHECKING, AsyncGenerator

from websocket import WebSocketConnectionClosedException

from ._tts_base import TTSModelBase
from ._tts_response import TTSResponse
from .._logging import logger
from ..message import Msg, AudioBlock, Base64Source
from ..types import JSONSerializableObject

Expand Down Expand Up @@ -82,6 +87,26 @@ def on_event(self, response: dict[str, Any]) -> None:
traceback.print_exc()
self.finish_event.set()

def on_close(self, close_status_code: int, close_msg: str) -> None:
    """Handle the closing of the WebSocket connection.

    Both synchronisation events are set unconditionally, and a warning
    is logged when a truthy close status code is supplied.

    Args:
        close_status_code (`int`):
            The close status code.
        close_msg (`str`):
            The close message.
    """
    # Release anything waiting on either event so that a closed
    # connection cannot leave a waiter blocked (prevents deadlock).
    for pending_event in (self.finish_event, self.chunk_event):
        pending_event.set()

    # Nothing to report when no (or a zero) status code was given.
    if not close_status_code:
        return

    logger.warning(
        "TTS WebSocket connection closed with code %s: %s",
        close_status_code,
        close_msg,
    )

async def get_audio_data(self, block: bool) -> TTSResponse:
"""Get the current accumulated audio data as base64 string so far.

Expand Down Expand Up @@ -164,6 +189,10 @@ async def _reset(self) -> None:
self.chunk_event.clear()
self._audio_data = ""

def has_audio_data(self) -> bool:
    """Report whether any audio data has been accumulated.

    Returns:
        `bool`:
            `True` when `self._audio_data` is truthy, otherwise
            `False`.
    """
    if self._audio_data:
        return True
    return False

return _DashScopeRealtimeTTSCallback


Expand Down Expand Up @@ -196,6 +225,8 @@ def __init__(
cold_start_words: int | None = None,
client_kwargs: dict[str, JSONSerializableObject] | None = None,
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
max_retries: int = 3,
retry_delay: float = 5.0,
) -> None:
"""Initialize the DashScope TTS model by specifying the model, voice,
and other parameters.
Expand Down Expand Up @@ -240,6 +271,10 @@ def __init__(
optional):
The extra keyword arguments used in DashScope realtime tts API
generation.
max_retries (`int`, defaults to 3):
The maximum number of synthesis attempts (including the first one)
before giving up.
retry_delay (`float`, defaults to 5.0):
The initial delay in seconds before retrying. The delay is doubled
after each failed attempt (exponential backoff).
"""
super().__init__(model_name=model_name, stream=stream)

Expand All @@ -255,6 +290,8 @@ def __init__(
self.cold_start_words = cold_start_words
self.client_kwargs = client_kwargs or {}
self.generate_kwargs = generate_kwargs or {}
self.max_retries = max_retries
self.retry_delay = retry_delay

# Initialize TTS client
# Save callback reference (for DashScope SDK)
Expand Down Expand Up @@ -298,9 +335,29 @@ async def close(self) -> None:

self._connected = False

self._tts_client.finish()
self._tts_client.close()

async def _reconnect(self) -> None:
    """Reconnect to the TTS service by recreating the client.

    The current client is closed on a best-effort basis, a new callback
    instance and a new `QwenTtsRealtime` client are created from the
    stored model name and client keyword arguments, the per-request
    streaming state is reset, and `connect()` is awaited.
    """
    from dashscope.audio.qwen_tts_realtime import QwenTtsRealtime

    try:
        self._tts_client.close()
    except Exception:  # pylint: disable=broad-except
        # Closing is best effort (we are replacing the client anyway),
        # but keep a trace of the failure instead of silently
        # discarding it so connection problems remain diagnosable.
        logger.debug(
            "TTS: failed to close the previous client before "
            "reconnecting.",
            exc_info=True,
        )

    # A fresh callback gives the new client its own events and an
    # empty audio buffer.
    self._dashscope_callback = _get_qwen_tts_realtime_callback_class()()
    self._tts_client = QwenTtsRealtime(
        model=self.model_name,
        callback=self._dashscope_callback,
        **self.client_kwargs,
    )

    # Reset the streaming state before connecting again.
    self._connected = False
    self._first_send = True
    self._current_msg_id = None
    self._current_prefix = ""
    await self.connect()

async def push(
self,
msg: Msg,
Expand Down Expand Up @@ -362,7 +419,12 @@ async def push(
delta_to_send = text.removeprefix(self._current_prefix)

if delta_to_send:
self._tts_client.append_text(delta_to_send)
try:
self._tts_client.append_text(delta_to_send)
except WebSocketConnectionClosedException:
# Connection closed, return empty response
# synthesize() will handle retry
return TTSResponse(content=None)

# Record sent prefix
self._current_prefix += delta_to_send
Expand Down Expand Up @@ -399,7 +461,11 @@ async def synthesize(
"TTS model is not connected. Call `connect()` first.",
)

if self._current_msg_id is not None and self._current_msg_id != msg.id:
if (
self._current_msg_id is not None
and msg
and self._current_msg_id != msg.id
):
raise RuntimeError(
"DashScopeRealtimeTTSModel can only handle one streaming "
"input request at a time. Please ensure that all chunks "
Expand All @@ -416,19 +482,85 @@ async def synthesize(
self._current_prefix,
)

# Determine if we should send text based on cold start settings only
# for the first input chunk and not the last chunk
if delta_to_send:
self._tts_client.append_text(delta_to_send)
full_text = (msg.get_text_content() or "") if msg else ""

# To keep correct prefix tracking
self._current_prefix += delta_to_send
self._first_send = False
# Synthesize with retry - if we have text but get no audio, retry
delay = self.retry_delay

# We need to block until synthesis is complete to get all audio
self._tts_client.commit()
self._tts_client.finish()
for attempt in range(self.max_retries):
try:
# Send remaining text if any
if delta_to_send:
self._tts_client.append_text(delta_to_send)
self._current_prefix += delta_to_send
self._first_send = False

# Commit and finish
self._tts_client.commit()
self._tts_client.finish()

# Wait for synthesis to complete
self._dashscope_callback.finish_event.wait()

# Check if we got audio (only retry if we had text but no
# audio)
has_audio = self._dashscope_callback.has_audio_data()
if full_text and not has_audio:
if attempt < self.max_retries - 1:
logger.warning(
"TTS: no audio received, retrying (%d/%d) in "
"%.1fs...",
attempt + 1,
self.max_retries,
delay,
)
await asyncio.sleep(delay)
await self._reconnect()
# After reconnect, need to resend full text
delta_to_send = full_text
delay *= 2
continue
logger.error(
"TTS: no audio after %d attempts.",
self.max_retries,
)
# Reset state before raising
self._current_msg_id = None
self._first_send = True
self._current_prefix = ""
raise RuntimeError(
f"TTS synthesis failed: no audio after"
f" {self.max_retries} attempts",
)

# Success
break

except WebSocketConnectionClosedException:
if attempt < self.max_retries - 1:
logger.warning(
"TTS failed, retrying (%d/%d) in %.1fs...",
attempt + 1,
self.max_retries,
delay,
)
await asyncio.sleep(delay)
await self._reconnect()
# After reconnect, need to resend full text
delta_to_send = full_text
delay *= 2
else:
logger.error(
"TTS failed after %d attempts.",
self.max_retries,
)
# Reset state before raising
self._current_msg_id = None
self._first_send = True
self._current_prefix = ""
raise

# Get result
if self.stream:
# Return an async generator for audio chunks
res = self._dashscope_callback.get_audio_chunk()
Expand Down
14 changes: 14 additions & 0 deletions tests/tts_dashscope_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@ async def test_synthesize_non_streaming(self) -> None:
api_key=self.api_key,
stream=False,
) as model:
# Mock finish_event to not block
model._dashscope_callback.finish_event = Mock()
model._dashscope_callback.finish_event.wait = Mock()
# Mock has_audio_data to return True (skip retry)
model._dashscope_callback.has_audio_data = Mock(
return_value=True,
)
model._dashscope_callback.get_audio_data = AsyncMock(
return_value=TTSResponse(
content=AudioBlock(
Expand Down Expand Up @@ -169,6 +176,13 @@ async def test_synthesize_streaming(self) -> None:
api_key=self.api_key,
stream=True,
) as model:
# Mock finish_event to not block
model._dashscope_callback.finish_event = Mock()
model._dashscope_callback.finish_event.wait = Mock()
# Mock has_audio_data to return True (skip retry)
model._dashscope_callback.has_audio_data = Mock(
return_value=True,
)

async def mock_generator() -> AsyncGenerator[
TTSResponse,
Expand Down