diff --git a/src/edge_tts/communicate.py b/src/edge_tts/communicate.py
index c53f007..36c5dff 100644
--- a/src/edge_tts/communicate.py
+++ b/src/edge_tts/communicate.py
@@ -5,7 +5,6 @@
import asyncio
import concurrent.futures
import json
-import re
import ssl
import time
import uuid
@@ -28,37 +27,38 @@
import aiohttp
import certifi
-from edge_tts.exceptions import (
+from .constants import WSS_URL
+from .exceptions import (
NoAudioReceived,
UnexpectedResponse,
UnknownResponse,
WebSocketError,
)
-
-from .constants import WSS_URL
+from .models import TTSConfig
-def get_headers_and_data(data: Union[str, bytes]) -> Tuple[Dict[bytes, bytes], bytes]:
+def get_headers_and_data(
+ data: bytes, header_length: int
+) -> Tuple[Dict[bytes, bytes], bytes]:
"""
Returns the headers and data from the given data.
Args:
- data (str or bytes): The data to be parsed.
+ data (bytes): The data to be parsed.
+ header_length (int): The length of the header.
Returns:
tuple: The headers and data to be used in the request.
"""
- if isinstance(data, str):
- data = data.encode("utf-8")
if not isinstance(data, bytes):
- raise TypeError("data must be str or bytes")
+ raise TypeError("data must be bytes")
headers = {}
- for line in data[: data.find(b"\r\n\r\n")].split(b"\r\n"):
+ for line in data[:header_length].split(b"\r\n"):
key, value = line.split(b":", 1)
headers[key] = value
- return headers, data[data.find(b"\r\n\r\n") + 4 :]
+ return headers, data[header_length + 2 :]
def remove_incompatible_characters(string: Union[str, bytes]) -> str:
@@ -154,24 +154,32 @@ def split_text_by_byte_length(
yield new_text
-def mkssml(
- text: Union[str, bytes], voice: str, rate: str, volume: str, pitch: str
-) -> str:
+def mkssml(tc: TTSConfig, escaped_text: Union[str, bytes]) -> str:
"""
Creates a SSML string from the given parameters.
+ Args:
+ tc (TTSConfig): The TTS configuration.
+ escaped_text (str or bytes): The escaped text. If bytes, it must be UTF-8 encoded.
+
Returns:
str: The SSML string.
"""
- if isinstance(text, bytes):
- text = text.decode("utf-8")
- ssml = (
+ # If the text is bytes, convert it to a string.
+ if isinstance(escaped_text, bytes):
+ escaped_text = escaped_text.decode("utf-8")
+
+ # Return the SSML string.
+ return (
         "<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='en-US'>"
-        f"<voice name='{voice}'><prosody pitch='{pitch}' rate='{rate}' volume='{volume}'>"
-        f"{text}</prosody></voice></speak>"
+        f"<voice name='{tc.voice}'>"
+        f"<prosody pitch='{tc.pitch}' rate='{tc.rate}' volume='{tc.volume}'>"
+        f"{escaped_text}"
+        "</prosody>"
+        "</voice>"
+        "</speak>"
)
- return ssml
def date_to_string() -> str:
@@ -207,7 +215,7 @@ def ssml_headers_plus_data(request_id: str, timestamp: str, ssml: str) -> str:
)
-def calc_max_mesg_size(voice: str, rate: str, volume: str, pitch: str) -> int:
+def calc_max_mesg_size(tts_config: TTSConfig) -> int:
"""Calculates the maximum message size for the given voice, rate, and volume.
Returns:
@@ -219,7 +227,7 @@ def calc_max_mesg_size(voice: str, rate: str, volume: str, pitch: str) -> int:
ssml_headers_plus_data(
connect_id(),
date_to_string(),
- mkssml("", voice, rate, volume, pitch),
+ mkssml(tts_config, ""),
)
)
+ 50 # margin of error
@@ -232,25 +240,6 @@ class Communicate:
Class for communicating with the service.
"""
- @staticmethod
- def validate_string_param(param_name: str, param_value: str, pattern: str) -> str:
- """
- Validates the given string parameter based on type and pattern.
-
- Args:
- param_name (str): The name of the parameter.
- param_value (str): The value of the parameter.
- pattern (str): The pattern to validate the parameter against.
-
- Returns:
- str: The validated parameter.
- """
- if not isinstance(param_value, str):
- raise TypeError(f"{param_name} must be str")
- if re.match(pattern, param_value) is None:
- raise ValueError(f"Invalid {param_name} '{param_value}'.")
- return param_value
-
def __init__(
self,
text: str,
@@ -269,46 +258,30 @@ def __init__(
Raises:
ValueError: If the voice is not valid.
"""
+
+ # Validate TTS settings and store the TTSConfig object.
+ self.tts_config = TTSConfig(voice, rate, volume, pitch)
+
+ # Validate the text parameter.
if not isinstance(text, str):
raise TypeError("text must be str")
- self.text: str = text
-
- # Possible values for voice are:
- # - Microsoft Server Speech Text to Speech Voice (cy-GB, NiaNeural)
- # - cy-GB-NiaNeural
- # - fil-PH-AngeloNeural
- # Always send the first variant as that is what Microsoft Edge does.
- if not isinstance(voice, str):
- raise TypeError("voice must be str")
- self.voice: str = voice
- match = re.match(r"^([a-z]{2,})-([A-Z]{2,})-(.+Neural)$", voice)
- if match is not None:
- lang = match.group(1)
- region = match.group(2)
- name = match.group(3)
- if name.find("-") != -1:
- region = region + "-" + name[: name.find("-")]
- name = name[name.find("-") + 1 :]
- self.voice = (
- "Microsoft Server Speech Text to Speech Voice"
- + f" ({lang}-{region}, {name})"
- )
- self.voice = self.validate_string_param(
- "voice",
- self.voice,
- r"^Microsoft Server Speech Text to Speech Voice \(.+,.+\)$",
+ # Split the text into multiple strings and store them.
+ self.texts = split_text_by_byte_length(
+ escape(remove_incompatible_characters(text)),
+ calc_max_mesg_size(self.tts_config),
)
- self.rate = self.validate_string_param("rate", rate, r"^[+-]\d+%$")
- self.volume = self.validate_string_param("volume", volume, r"^[+-]\d+%$")
- self.pitch = self.validate_string_param("pitch", pitch, r"^[+-]\d+Hz$")
+ # Validate the proxy parameter.
if proxy is not None and not isinstance(proxy, str):
raise TypeError("proxy must be str")
self.proxy: Optional[str] = proxy
- if not isinstance(connect_timeout, int) or not isinstance(receive_timeout, int):
- raise TypeError("connect_timeout and receive_timeout must be int")
+ # Validate the timeout parameters.
+ if not isinstance(connect_timeout, int):
+ raise TypeError("connect_timeout must be int")
+ if not isinstance(receive_timeout, int):
+ raise TypeError("receive_timeout must be int")
self.session_timeout = aiohttp.ClientTimeout(
total=None,
connect=None,
@@ -316,9 +289,34 @@ def __init__(
sock_read=receive_timeout,
)
- async def stream(self) -> AsyncGenerator[Dict[str, Any], None]:
- """Streams audio and metadata from the service."""
-
+ # Store current state of TTS.
+ self.state: Dict[str, Any] = {
+ "partial_text": None,
+ "offset_compensation": 0,
+ "last_duration_offset": 0,
+ "stream_was_called": False,
+ }
+
+ def __parse_metadata(self, data: bytes) -> Dict[str, Any]:
+ for meta_obj in json.loads(data)["Metadata"]:
+ meta_type = meta_obj["Type"]
+ if meta_type == "WordBoundary":
+ current_offset = (
+ meta_obj["Data"]["Offset"] + self.state["offset_compensation"]
+ )
+ current_duration = meta_obj["Data"]["Duration"]
+ return {
+ "type": meta_type,
+ "offset": current_offset,
+ "duration": current_duration,
+ "text": meta_obj["Data"]["text"]["Text"],
+ }
+ if meta_type in ("SessionEnd",):
+ continue
+ raise UnknownResponse(f"Unknown metadata type: {meta_type}")
+ raise UnexpectedResponse("No WordBoundary metadata found")
+
+ async def __stream(self) -> AsyncGenerator[Dict[str, Any], None]:
async def send_command_request() -> None:
"""Sends the request to the service."""
@@ -342,55 +340,25 @@ async def send_command_request() -> None:
"}}}}\r\n"
)
- async def send_ssml_request() -> bool:
+ async def send_ssml_request() -> None:
"""Sends the SSML request to the service."""
- # Get the next string from the generator.
- text = next(texts, None)
-
- # If there are no more strings, return False.
- if text is None:
- return False
-
- # Send the request to the service and return True.
+ # Send the request to the service.
await websocket.send_str(
ssml_headers_plus_data(
connect_id(),
date_to_string(),
- mkssml(text, self.voice, self.rate, self.volume, self.pitch),
+ mkssml(
+ self.tts_config,
+ self.state["partial_text"],
+ ),
)
)
- return True
-
- def parse_metadata() -> Dict[str, Any]:
- for meta_obj in json.loads(data)["Metadata"]:
- meta_type = meta_obj["Type"]
- if meta_type == "WordBoundary":
- current_offset = meta_obj["Data"]["Offset"] + offset_compensation
- current_duration = meta_obj["Data"]["Duration"]
- return {
- "type": meta_type,
- "offset": current_offset,
- "duration": current_duration,
- "text": meta_obj["Data"]["text"]["Text"],
- }
- if meta_type in ("SessionEnd",):
- continue
- raise UnknownResponse(f"Unknown metadata type: {meta_type}")
- raise UnexpectedResponse("No WordBoundary metadata found")
-
- # Split the text into multiple strings if it is too long for the service.
- texts = split_text_by_byte_length(
- escape(remove_incompatible_characters(self.text)),
- calc_max_mesg_size(self.voice, self.rate, self.volume, self.pitch),
- )
-
- # Keep track of last duration + offset to calculate the offset
- # upon word split.
- last_duration_offset = 0
- # Current offset compensations.
- offset_compensation = 0
+ # audio_was_received indicates whether we have received audio data
+ # from the websocket. This is so we can raise an exception if we
+ # don't receive any audio data.
+ audio_was_received = False
# Create a new connection to the service.
ssl_ctx = ssl.create_default_context(cafile=certifi.where())
@@ -412,11 +380,6 @@ def parse_metadata() -> Dict[str, Any]:
},
ssl=ssl_ctx,
) as websocket:
- # audio_was_received indicates whether we have received audio data
- # from the websocket. This is so we can raise an exception if we
- # don't receive any audio data.
- audio_was_received = False
-
# Send the request to the service.
await send_command_request()
@@ -425,53 +388,91 @@ def parse_metadata() -> Dict[str, Any]:
async for received in websocket:
if received.type == aiohttp.WSMsgType.TEXT:
- parameters, data = get_headers_and_data(received.data)
- path = parameters.get(b"Path")
+ encoded_data: bytes = received.data.encode("utf-8")
+ parameters, data = get_headers_and_data(
+ encoded_data, encoded_data.find(b"\r\n\r\n")
+ )
+
+ path = parameters.get(b"Path", None)
if path == b"audio.metadata":
# Parse the metadata and yield it.
- parsed_metadata = parse_metadata()
+ parsed_metadata = self.__parse_metadata(data)
yield parsed_metadata
# Update the last duration offset for use by the next SSML request.
- last_duration_offset = (
+ self.state["last_duration_offset"] = (
parsed_metadata["offset"] + parsed_metadata["duration"]
)
elif path == b"turn.end":
# Update the offset compensation for the next SSML request.
- offset_compensation = last_duration_offset
+ self.state["offset_compensation"] = self.state[
+ "last_duration_offset"
+ ]
# Use average padding typically added by the service
# to the end of the audio data. This seems to work pretty
# well for now, but we might ultimately need to use a
# more sophisticated method like using ffmpeg to get
# the actual duration of the audio data.
- offset_compensation += 8_750_000
+ self.state["offset_compensation"] += 8_750_000
- # Send the next SSML request to the service.
- if not await send_ssml_request():
- break
+ # Exit the loop so we can send the next SSML request.
+ break
elif path not in (b"response", b"turn.start"):
- raise UnknownResponse(
- "The response from the service is not recognized.\n"
- + received.data
- )
+ raise UnknownResponse("Unknown path received")
elif received.type == aiohttp.WSMsgType.BINARY:
+ # Message is too short to contain header length.
if len(received.data) < 2:
raise UnexpectedResponse(
"We received a binary message, but it is missing the header length."
)
+ # The first two bytes of the binary message contain the header length.
header_length = int.from_bytes(received.data[:2], "big")
- if len(received.data) < header_length + 2:
+ if header_length > len(received.data):
+ raise UnexpectedResponse(
+ "The header length is greater than the length of the data."
+ )
+
+ # Parse the headers and data from the binary message.
+ parameters, data = get_headers_and_data(
+ received.data, header_length
+ )
+
+ # Check if the path is audio.
+ if parameters.get(b"Path") != b"audio":
+ raise UnexpectedResponse(
+ "Received binary message, but the path is not audio."
+ )
+
+ # At termination of the stream, the service sends a binary message
+ # with no Content-Type; this is expected. What is not expected is for
+ # an MPEG audio stream to be sent with no data.
+ content_type = parameters.get(b"Content-Type", None)
+ if content_type not in [b"audio/mpeg", None]:
+ raise UnexpectedResponse(
+ "Received binary message, but with an unexpected Content-Type."
+ )
+
+ # We only allow no Content-Type if there is no data.
+ if content_type is None:
+ if len(data) == 0:
+ continue
+
+ # If the data is not empty, then we need to raise an exception.
+ raise UnexpectedResponse(
+ "Received binary message with no Content-Type, but with data."
+ )
+
+ # If the data is empty now, then we need to raise an exception.
+ if len(data) == 0:
raise UnexpectedResponse(
- "We received a binary message, but it is missing the audio data."
+ "Received binary message, but it is missing the audio data."
)
+ # Yield the audio data.
audio_was_received = True
- yield {
- "type": "audio",
- "data": received.data[header_length + 2 :],
- }
+ yield {"type": "audio", "data": data}
elif received.type == aiohttp.WSMsgType.ERROR:
raise WebSocketError(
received.data if received.data else "Unknown error"
@@ -482,6 +483,29 @@ def parse_metadata() -> Dict[str, Any]:
"No audio was received. Please verify that your parameters are correct."
)
+ async def stream(
+ self,
+ ) -> AsyncGenerator[Dict[str, Any], None]:
+ """
+ Streams audio and metadata from the service.
+
+ Raises:
+ NoAudioReceived: If no audio is received from the service.
+ UnexpectedResponse: If the response from the service is unexpected.
+ UnknownResponse: If the response from the service is unknown.
+ WebSocketError: If there is an error with the websocket.
+ """
+
+ # Check if stream was called before.
+ if self.state["stream_was_called"]:
+ raise RuntimeError("stream can only be called once.")
+ self.state["stream_was_called"] = True
+
+ # Stream the audio and metadata from the service.
+ for self.state["partial_text"] in self.texts:
+ async for message in self.__stream():
+ yield message
+
async def save(
self,
audio_fname: Union[str, bytes],
diff --git a/src/edge_tts/models.py b/src/edge_tts/models.py
new file mode 100644
index 0000000..2265a8b
--- /dev/null
+++ b/src/edge_tts/models.py
@@ -0,0 +1,70 @@
+"""Models for the Edge TTS module."""
+
+import re
+from dataclasses import dataclass
+
+
+@dataclass
+class TTSConfig:
+ """
+ Represents the internal TTS configuration for Edge TTS's communicate class.
+ """
+
+ voice: str
+ rate: str
+ volume: str
+ pitch: str
+
+ @staticmethod
+ def validate_string_param(param_name: str, param_value: str, pattern: str) -> str:
+ """
+ Validates the given string parameter based on type and pattern.
+
+ Args:
+ param_name (str): The name of the parameter.
+ param_value (str): The value of the parameter.
+ pattern (str): The pattern to validate the parameter against.
+
+ Returns:
+ str: The validated parameter.
+ """
+ if not isinstance(param_value, str):
+ raise TypeError(f"{param_name} must be str")
+ if re.match(pattern, param_value) is None:
+ raise ValueError(f"Invalid {param_name} '{param_value}'.")
+ return param_value
+
+ def __post_init__(self) -> None:
+ """
+ Validates the TTSConfig object after initialization.
+ """
+
+ # Possible values for voice are:
+ # - Microsoft Server Speech Text to Speech Voice (cy-GB, NiaNeural)
+ # - cy-GB-NiaNeural
+ # - fil-PH-AngeloNeural
+ # Always send the first variant as that is what Microsoft Edge does.
+ if not isinstance(self.voice, str):
+ raise TypeError("voice must be str")
+ match = re.match(r"^([a-z]{2,})-([A-Z]{2,})-(.+Neural)$", self.voice)
+ if match is not None:
+ lang = match.group(1)
+ region = match.group(2)
+ name = match.group(3)
+ if name.find("-") != -1:
+ region = region + "-" + name[: name.find("-")]
+ name = name[name.find("-") + 1 :]
+ self.voice = (
+ "Microsoft Server Speech Text to Speech Voice"
+ + f" ({lang}-{region}, {name})"
+ )
+
+ # Validate the rate, volume, and pitch parameters.
+ self.validate_string_param(
+ "voice",
+ self.voice,
+ r"^Microsoft Server Speech Text to Speech Voice \(.+,.+\)$",
+ )
+ self.validate_string_param("rate", self.rate, r"^[+-]\d+%$")
+ self.validate_string_param("volume", self.volume, r"^[+-]\d+%$")
+ self.validate_string_param("pitch", self.pitch, r"^[+-]\d+Hz$")
diff --git a/src/edge_tts/util.py b/src/edge_tts/util.py
index 7f674e5..532a140 100644
--- a/src/edge_tts/util.py
+++ b/src/edge_tts/util.py
@@ -8,7 +8,7 @@
from io import TextIOWrapper
from typing import Any, TextIO, Union
-from edge_tts import Communicate, SubMaker, list_voices
+from . import Communicate, SubMaker, list_voices
async def _print_voices(*, proxy: str) -> None: