diff --git a/src/edge_tts/communicate.py b/src/edge_tts/communicate.py
index c53f007..36c5dff 100644
--- a/src/edge_tts/communicate.py
+++ b/src/edge_tts/communicate.py
@@ -5,7 +5,6 @@
 import asyncio
 import concurrent.futures
 import json
-import re
 import ssl
 import time
 import uuid
@@ -28,37 +27,38 @@
 import aiohttp
 import certifi
 
-from edge_tts.exceptions import (
+from .constants import WSS_URL
+from .exceptions import (
     NoAudioReceived,
     UnexpectedResponse,
     UnknownResponse,
     WebSocketError,
 )
-
-from .constants import WSS_URL
+from .models import TTSConfig
 
 
-def get_headers_and_data(data: Union[str, bytes]) -> Tuple[Dict[bytes, bytes], bytes]:
+def get_headers_and_data(
+    data: bytes, header_length: int
+) -> Tuple[Dict[bytes, bytes], bytes]:
     """
     Returns the headers and data from the given data.
 
     Args:
-        data (str or bytes): The data to be parsed.
+        data (bytes): The data to be parsed.
+        header_length (int): The length of the header.
 
     Returns:
         tuple: The headers and data to be used in the request.
     """
-    if isinstance(data, str):
-        data = data.encode("utf-8")
     if not isinstance(data, bytes):
-        raise TypeError("data must be str or bytes")
+        raise TypeError("data must be bytes")
 
     headers = {}
-    for line in data[: data.find(b"\r\n\r\n")].split(b"\r\n"):
+    for line in data[:header_length].split(b"\r\n"):
         key, value = line.split(b":", 1)
         headers[key] = value
 
-    return headers, data[data.find(b"\r\n\r\n") + 4 :]
+    return headers, data[header_length + 2 :]
 
 
 def remove_incompatible_characters(string: Union[str, bytes]) -> str:
@@ -154,24 +154,32 @@ def split_text_by_byte_length(
         yield new_text
 
 
-def mkssml(
-    text: Union[str, bytes], voice: str, rate: str, volume: str, pitch: str
-) -> str:
+def mkssml(tc: TTSConfig, escaped_text: Union[str, bytes]) -> str:
     """
     Creates a SSML string from the given parameters.
 
+    Args:
+        tc (TTSConfig): The TTS configuration.
+        escaped_text (str or bytes): The escaped text. If bytes, it must be UTF-8 encoded.
+
     Returns:
         str: The SSML string.
     """
-    if isinstance(text, bytes):
-        text = text.decode("utf-8")
-    ssml = (
+    # If the text is bytes, convert it to a string.
+    if isinstance(escaped_text, bytes):
+        escaped_text = escaped_text.decode("utf-8")
+
+    # Return the SSML string.
+    return (
         "<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='en-US'>"
-        f"<voice name='{voice}'><prosody pitch='{pitch}' rate='{rate}' volume='{volume}'>"
-        f"{text}</prosody></voice></speak>"
+        f"<voice name='{tc.voice}'>"
+        f"<prosody pitch='{tc.pitch}' rate='{tc.rate}' volume='{tc.volume}'>"
+        f"{escaped_text}"
+        "</prosody>"
+        "</voice>"
+        "</speak>"
     )
-    return ssml
 
 
 def date_to_string() -> str:
@@ -207,7 +215,7 @@ def ssml_headers_plus_data(request_id: str, timestamp: str, ssml: str) -> str:
     )
 
 
-def calc_max_mesg_size(voice: str, rate: str, volume: str, pitch: str) -> int:
+def calc_max_mesg_size(tts_config: TTSConfig) -> int:
     """Calculates the maximum message size for the given voice, rate, and volume.
 
     Returns:
@@ -219,7 +227,7 @@ def calc_max_mesg_size(voice: str, rate: str, volume: str, pitch: str) -> int:
             ssml_headers_plus_data(
                 connect_id(),
                 date_to_string(),
-                mkssml("", voice, rate, volume, pitch),
+                mkssml(tts_config, ""),
             )
         )
         + 50  # margin of error
@@ -232,25 +240,6 @@ class Communicate:
     Class for communicating with the service.
     """
 
-    @staticmethod
-    def validate_string_param(param_name: str, param_value: str, pattern: str) -> str:
-        """
-        Validates the given string parameter based on type and pattern.
-
-        Args:
-            param_name (str): The name of the parameter.
-            param_value (str): The value of the parameter.
-            pattern (str): The pattern to validate the parameter against.
-
-        Returns:
-            str: The validated parameter.
- """ - if not isinstance(param_value, str): - raise TypeError(f"{param_name} must be str") - if re.match(pattern, param_value) is None: - raise ValueError(f"Invalid {param_name} '{param_value}'.") - return param_value - def __init__( self, text: str, @@ -269,46 +258,30 @@ def __init__( Raises: ValueError: If the voice is not valid. """ + + # Validate TTS settings and store the TTSConfig object. + self.tts_config = TTSConfig(voice, rate, volume, pitch) + + # Validate the text parameter. if not isinstance(text, str): raise TypeError("text must be str") - self.text: str = text - - # Possible values for voice are: - # - Microsoft Server Speech Text to Speech Voice (cy-GB, NiaNeural) - # - cy-GB-NiaNeural - # - fil-PH-AngeloNeural - # Always send the first variant as that is what Microsoft Edge does. - if not isinstance(voice, str): - raise TypeError("voice must be str") - self.voice: str = voice - match = re.match(r"^([a-z]{2,})-([A-Z]{2,})-(.+Neural)$", voice) - if match is not None: - lang = match.group(1) - region = match.group(2) - name = match.group(3) - if name.find("-") != -1: - region = region + "-" + name[: name.find("-")] - name = name[name.find("-") + 1 :] - self.voice = ( - "Microsoft Server Speech Text to Speech Voice" - + f" ({lang}-{region}, {name})" - ) - self.voice = self.validate_string_param( - "voice", - self.voice, - r"^Microsoft Server Speech Text to Speech Voice \(.+,.+\)$", + # Split the text into multiple strings and store them. + self.texts = split_text_by_byte_length( + escape(remove_incompatible_characters(text)), + calc_max_mesg_size(self.tts_config), ) - self.rate = self.validate_string_param("rate", rate, r"^[+-]\d+%$") - self.volume = self.validate_string_param("volume", volume, r"^[+-]\d+%$") - self.pitch = self.validate_string_param("pitch", pitch, r"^[+-]\d+Hz$") + # Validate the proxy parameter. if proxy is not None and not isinstance(proxy, str): raise TypeError("proxy must be str") self.proxy: Optional[str] = proxy - if not isinstance(connect_timeout, int) or not isinstance(receive_timeout, int): - raise TypeError("connect_timeout and receive_timeout must be int") + # Validate the timeout parameters. + if not isinstance(connect_timeout, int): + raise TypeError("connect_timeout must be int") + if not isinstance(receive_timeout, int): + raise TypeError("receive_timeout must be int") self.session_timeout = aiohttp.ClientTimeout( total=None, connect=None, @@ -316,9 +289,34 @@ def __init__( sock_read=receive_timeout, ) - async def stream(self) -> AsyncGenerator[Dict[str, Any], None]: - """Streams audio and metadata from the service.""" - + # Store current state of TTS. 
+        self.state: Dict[str, Any] = {
+            "partial_text": None,
+            "offset_compensation": 0,
+            "last_duration_offset": 0,
+            "stream_was_called": False,
+        }
+
+    def __parse_metadata(self, data: bytes) -> Dict[str, Any]:
+        for meta_obj in json.loads(data)["Metadata"]:
+            meta_type = meta_obj["Type"]
+            if meta_type == "WordBoundary":
+                current_offset = (
+                    meta_obj["Data"]["Offset"] + self.state["offset_compensation"]
+                )
+                current_duration = meta_obj["Data"]["Duration"]
+                return {
+                    "type": meta_type,
+                    "offset": current_offset,
+                    "duration": current_duration,
+                    "text": meta_obj["Data"]["text"]["Text"],
+                }
+            if meta_type in ("SessionEnd",):
+                continue
+            raise UnknownResponse(f"Unknown metadata type: {meta_type}")
+        raise UnexpectedResponse("No WordBoundary metadata found")
+
+    async def __stream(self) -> AsyncGenerator[Dict[str, Any], None]:
         async def send_command_request() -> None:
             """Sends the request to the service."""
 
@@ -342,55 +340,25 @@ async def send_command_request() -> None:
                 "}}}}\r\n"
             )
 
-        async def send_ssml_request() -> bool:
+        async def send_ssml_request() -> None:
             """Sends the SSML request to the service."""
 
-            # Get the next string from the generator.
-            text = next(texts, None)
-
-            # If there are no more strings, return False.
-            if text is None:
-                return False
-
-            # Send the request to the service and return True.
+            # Send the request to the service.
             await websocket.send_str(
                 ssml_headers_plus_data(
                     connect_id(),
                     date_to_string(),
-                    mkssml(text, self.voice, self.rate, self.volume, self.pitch),
+                    mkssml(
+                        self.tts_config,
+                        self.state["partial_text"],
+                    ),
                 )
             )
-            return True
-
-        def parse_metadata() -> Dict[str, Any]:
-            for meta_obj in json.loads(data)["Metadata"]:
-                meta_type = meta_obj["Type"]
-                if meta_type == "WordBoundary":
-                    current_offset = meta_obj["Data"]["Offset"] + offset_compensation
-                    current_duration = meta_obj["Data"]["Duration"]
-                    return {
-                        "type": meta_type,
-                        "offset": current_offset,
-                        "duration": current_duration,
-                        "text": meta_obj["Data"]["text"]["Text"],
-                    }
-                if meta_type in ("SessionEnd",):
-                    continue
-                raise UnknownResponse(f"Unknown metadata type: {meta_type}")
-            raise UnexpectedResponse("No WordBoundary metadata found")
-
-        # Split the text into multiple strings if it is too long for the service.
-        texts = split_text_by_byte_length(
-            escape(remove_incompatible_characters(self.text)),
-            calc_max_mesg_size(self.voice, self.rate, self.volume, self.pitch),
-        )
-
-        # Keep track of last duration + offset to calculate the offset
-        # upon word split.
-        last_duration_offset = 0
-        # Current offset compensations.
-        offset_compensation = 0
+        # audio_was_received indicates whether we have received audio data
+        # from the websocket. This is so we can raise an exception if we
+        # don't receive any audio data.
+        audio_was_received = False
 
         # Create a new connection to the service.
         ssl_ctx = ssl.create_default_context(cafile=certifi.where())
@@ -412,11 +380,6 @@ def parse_metadata() -> Dict[str, Any]:
             },
             ssl=ssl_ctx,
         ) as websocket:
-            # audio_was_received indicates whether we have received audio data
-            # from the websocket. This is so we can raise an exception if we
-            # don't receive any audio data.
-            audio_was_received = False
-
             # Send the request to the service.
             await send_command_request()
@@ -425,53 +388,91 @@ def parse_metadata() -> Dict[str, Any]:
             async for received in websocket:
                 if received.type == aiohttp.WSMsgType.TEXT:
-                    parameters, data = get_headers_and_data(received.data)
-                    path = parameters.get(b"Path")
+                    encoded_data: bytes = received.data.encode("utf-8")
+                    parameters, data = get_headers_and_data(
+                        encoded_data, encoded_data.find(b"\r\n\r\n")
+                    )
+
+                    path = parameters.get(b"Path", None)
                     if path == b"audio.metadata":
                         # Parse the metadata and yield it.
-                        parsed_metadata = parse_metadata()
+                        parsed_metadata = self.__parse_metadata(data)
                         yield parsed_metadata
 
                         # Update the last duration offset for use by the next SSML request.
-                        last_duration_offset = (
+                        self.state["last_duration_offset"] = (
                             parsed_metadata["offset"] + parsed_metadata["duration"]
                         )
                     elif path == b"turn.end":
                         # Update the offset compensation for the next SSML request.
-                        offset_compensation = last_duration_offset
+                        self.state["offset_compensation"] = self.state[
+                            "last_duration_offset"
+                        ]
 
                         # Use average padding typically added by the service
                         # to the end of the audio data. This seems to work pretty
                         # well for now, but we might ultimately need to use a
                         # more sophisticated method like using ffmpeg to get
                         # the actual duration of the audio data.
-                        offset_compensation += 8_750_000
+                        self.state["offset_compensation"] += 8_750_000
 
-                        # Send the next SSML request to the service.
-                        if not await send_ssml_request():
-                            break
+                        # Exit the loop so we can send the next SSML request.
+                        break
                     elif path not in (b"response", b"turn.start"):
-                        raise UnknownResponse(
-                            "The response from the service is not recognized.\n"
-                            + received.data
-                        )
+                        raise UnknownResponse("Unknown path received")
                 elif received.type == aiohttp.WSMsgType.BINARY:
+                    # Message is too short to contain header length.
                     if len(received.data) < 2:
                         raise UnexpectedResponse(
                             "We received a binary message, but it is missing the header length."
                         )
 
+                    # The first two bytes of the binary message contain the header length.
                     header_length = int.from_bytes(received.data[:2], "big")
-                    if len(received.data) < header_length + 2:
+                    if header_length > len(received.data):
+                        raise UnexpectedResponse(
+                            "The header length is greater than the length of the data."
+                        )
+
+                    # Parse the headers and data from the binary message.
+                    parameters, data = get_headers_and_data(
+                        received.data, header_length
+                    )
+
+                    # Check if the path is audio.
+                    if parameters.get(b"Path") != b"audio":
+                        raise UnexpectedResponse(
+                            "Received binary message, but the path is not audio."
+                        )
+
+                    # At termination of the stream, the service sends a binary message
+                    # with no Content-Type; this is expected. What is not expected is for
+                    # an MPEG audio stream to be sent with no data.
+                    content_type = parameters.get(b"Content-Type", None)
+                    if content_type not in [b"audio/mpeg", None]:
+                        raise UnexpectedResponse(
+                            "Received binary message, but with an unexpected Content-Type."
+                        )
+
+                    # We only allow no Content-Type if there is no data.
+                    if content_type is None:
+                        if len(data) == 0:
+                            continue
+
+                        # If the data is not empty, then we need to raise an exception.
+                        raise UnexpectedResponse(
+                            "Received binary message with no Content-Type, but with data."
+                        )
+
+                    # If the data is empty now, then we need to raise an exception.
+                    if len(data) == 0:
                         raise UnexpectedResponse(
-                            "We received a binary message, but it is missing the audio data."
+                            "Received binary message, but it is missing the audio data."
                         )
 
+                    # Yield the audio data.
                     audio_was_received = True
-                    yield {
-                        "type": "audio",
-                        "data": received.data[header_length + 2 :],
-                    }
+                    yield {"type": "audio", "data": data}
                 elif received.type == aiohttp.WSMsgType.ERROR:
                     raise WebSocketError(
                         received.data if received.data else "Unknown error"
                     )
@@ -482,6 +483,29 @@ def parse_metadata() -> Dict[str, Any]:
                 "No audio was received. Please verify that your parameters are correct."
             )
 
+    async def stream(
+        self,
+    ) -> AsyncGenerator[Dict[str, Any], None]:
+        """
+        Streams audio and metadata from the service.
+
+        Raises:
+            NoAudioReceived: If no audio is received from the service.
+            UnexpectedResponse: If the response from the service is unexpected.
+            UnknownResponse: If the response from the service is unknown.
+            WebSocketError: If there is an error with the websocket.
+        """
+
+        # Check if stream was called before.
+        if self.state["stream_was_called"]:
+            raise RuntimeError("stream can only be called once.")
+        self.state["stream_was_called"] = True
+
+        # Stream the audio and metadata from the service.
+        for self.state["partial_text"] in self.texts:
+            async for message in self.__stream():
+                yield message
+
     async def save(
         self,
         audio_fname: Union[str, bytes],
diff --git a/src/edge_tts/models.py b/src/edge_tts/models.py
new file mode 100644
index 0000000..2265a8b
--- /dev/null
+++ b/src/edge_tts/models.py
@@ -0,0 +1,70 @@
+"""Models for the Edge TTS module."""
+
+import re
+from dataclasses import dataclass
+
+
+@dataclass
+class TTSConfig:
+    """
+    Represents the internal TTS configuration for Edge TTS's communicate class.
+    """
+
+    voice: str
+    rate: str
+    volume: str
+    pitch: str
+
+    @staticmethod
+    def validate_string_param(param_name: str, param_value: str, pattern: str) -> str:
+        """
+        Validates the given string parameter based on type and pattern.
+
+        Args:
+            param_name (str): The name of the parameter.
+            param_value (str): The value of the parameter.
+            pattern (str): The pattern to validate the parameter against.
+
+        Returns:
+            str: The validated parameter.
+        """
+        if not isinstance(param_value, str):
+            raise TypeError(f"{param_name} must be str")
+        if re.match(pattern, param_value) is None:
+            raise ValueError(f"Invalid {param_name} '{param_value}'.")
+        return param_value
+
+    def __post_init__(self) -> None:
+        """
+        Validates the TTSConfig object after initialization.
+        """
+
+        # Possible values for voice are:
+        # - Microsoft Server Speech Text to Speech Voice (cy-GB, NiaNeural)
+        # - cy-GB-NiaNeural
+        # - fil-PH-AngeloNeural
+        # Always send the first variant as that is what Microsoft Edge does.
+        if not isinstance(self.voice, str):
+            raise TypeError("voice must be str")
+        match = re.match(r"^([a-z]{2,})-([A-Z]{2,})-(.+Neural)$", self.voice)
+        if match is not None:
+            lang = match.group(1)
+            region = match.group(2)
+            name = match.group(3)
+            if name.find("-") != -1:
+                region = region + "-" + name[: name.find("-")]
+                name = name[name.find("-") + 1 :]
+            self.voice = (
+                "Microsoft Server Speech Text to Speech Voice"
+                + f" ({lang}-{region}, {name})"
+            )
+
+        # Validate the rate, volume, and pitch parameters.
+        self.validate_string_param(
+            "voice",
+            self.voice,
+            r"^Microsoft Server Speech Text to Speech Voice \(.+,.+\)$",
+        )
+        self.validate_string_param("rate", self.rate, r"^[+-]\d+%$")
+        self.validate_string_param("volume", self.volume, r"^[+-]\d+%$")
+        self.validate_string_param("pitch", self.pitch, r"^[+-]\d+Hz$")
diff --git a/src/edge_tts/util.py b/src/edge_tts/util.py
index 7f674e5..532a140 100644
--- a/src/edge_tts/util.py
+++ b/src/edge_tts/util.py
@@ -8,7 +8,7 @@
 from io import TextIOWrapper
 from typing import Any, TextIO, Union
 
-from edge_tts import Communicate, SubMaker, list_voices
+from . import Communicate, SubMaker, list_voices
 
 
 async def _print_voices(*, proxy: str) -> None:
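
Usage note (not part of the patch): the sketch below exercises the refactored helpers as they are defined in this diff. TTSConfig, mkssml(), and calc_max_mesg_size() are internal to the package, so the direct imports, the voice name, and the prosody values here are illustrative assumptions rather than documented API.

# A sketch only: exercises TTSConfig and mkssml() as introduced by this diff.
from edge_tts.communicate import calc_max_mesg_size, mkssml
from edge_tts.models import TTSConfig

# __post_init__ validates rate/volume/pitch and expands the short voice name
# into the full "Microsoft Server Speech Text to Speech Voice (...)" form.
tc = TTSConfig("en-GB-SoniaNeural", "+10%", "+0%", "+0Hz")
print(tc.voice)  # Microsoft Server Speech Text to Speech Voice (en-GB, SoniaNeural)

# mkssml() now pulls voice, rate, volume, and pitch from the config object;
# the caller is responsible for passing already-escaped text.
print(mkssml(tc, "Hello, world!"))

# The per-message budget used to split long input is still derived from an
# empty SSML request plus a margin of error.
print(calc_max_mesg_size(tc))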
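
A second sketch, also not part of the patch, of the public Communicate API after this refactor, assuming the constructor keeps its existing text/voice-plus-keyword-options signature; the voice and output path are arbitrary examples. The behavioural change to note is that stream() now keeps its state on the instance and may only be called once.

import asyncio

from edge_tts import Communicate


async def main() -> None:
    # One Communicate instance per utterance.
    communicate = Communicate("Hello, world!", "en-GB-SoniaNeural", rate="+10%")
    with open("hello.mp3", "wb") as f:
        async for chunk in communicate.stream():
            if chunk["type"] == "audio":
                f.write(chunk["data"])
            elif chunk["type"] == "WordBoundary":
                print(chunk["offset"], chunk["duration"], chunk["text"])
    # A second communicate.stream() call now raises
    # RuntimeError("stream can only be called once.").


asyncio.run(main())

Splitting the input into self.texts in __init__ and looping over it inside stream() is what lets __stream() open one connection per text chunk, while the offset_compensation entry in self.state keeps WordBoundary offsets continuous across those chunks.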