Skip to content

Commit

Permalink
Improve type support a bit more (#333)
Browse files Browse the repository at this point in the history
Also fix default voice for util.py

Signed-off-by: rany <[email protected]>
  • Loading branch information
rany2 authored Nov 23, 2024
1 parent a3d468c commit 0639576
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 32 deletions.
3 changes: 0 additions & 3 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,3 @@ warn_unreachable = True

strict_equality = True
strict = True

[mypy-edge_tts.voices]
disallow_any_decorated = False
18 changes: 7 additions & 11 deletions src/edge_tts/communicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from io import TextIOWrapper
from queue import Queue
from typing import (
Any,
AsyncGenerator,
ContextManager,
Dict,
Expand All @@ -26,16 +25,16 @@
import aiohttp
import certifi

from .constants import SEC_MS_GEC_VERSION, WSS_HEADERS, WSS_URL
from .constants import DEFAULT_VOICE, SEC_MS_GEC_VERSION, WSS_HEADERS, WSS_URL
from .data_classes import TTSConfig
from .drm import DRM
from .exceptions import (
NoAudioReceived,
UnexpectedResponse,
UnknownResponse,
WebSocketError,
)
from .models import TTSConfig
from .typing import TTSChunk
from .typing import CommunicateState, TTSChunk


def get_headers_and_data(
Expand Down Expand Up @@ -109,7 +108,7 @@ def split_text_by_byte_length(
text will be inside of an XML tag.
Args:
text (str or bytes): The string to be split.
text (str or bytes): The string to be split. If bytes, it must be UTF-8 encoded.
byte_length (int): The maximum byte length of each string in the list.
Yield:
Expand Down Expand Up @@ -166,12 +165,9 @@ def mkssml(tc: TTSConfig, escaped_text: Union[str, bytes]) -> str:
Returns:
str: The SSML string.
"""

# If the text is bytes, convert it to a string.
if isinstance(escaped_text, bytes):
escaped_text = escaped_text.decode("utf-8")

# Return the SSML string.
return (
"<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='en-US'>"
f"<voice name='{tc.voice}'>"
Expand Down Expand Up @@ -244,7 +240,7 @@ class Communicate:
def __init__(
self,
text: str,
voice: str = "en-US-EmmaMultilingualNeural",
voice: str = DEFAULT_VOICE,
*,
rate: str = "+0%",
volume: str = "+0%",
Expand Down Expand Up @@ -290,8 +286,8 @@ def __init__(
self.connector: Optional[aiohttp.BaseConnector] = connector

# Store current state of TTS.
self.state: Dict[str, Any] = {
"partial_text": None,
self.state: CommunicateState = {
"partial_text": b"",
"offset_compensation": 0,
"last_duration_offset": 0,
"stream_was_called": False,
Expand Down
2 changes: 2 additions & 0 deletions src/edge_tts/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
WSS_URL = f"wss://{BASE_URL}/edge/v1?TrustedClientToken={TRUSTED_CLIENT_TOKEN}"
VOICE_LIST = f"https://{BASE_URL}/voices/list?trustedclienttoken={TRUSTED_CLIENT_TOKEN}"

DEFAULT_VOICE = "en-US-EmmaMultilingualNeural"

CHROMIUM_FULL_VERSION = "130.0.2849.68"
CHROMIUM_MAJOR_VERSION = CHROMIUM_FULL_VERSION.split(".", maxsplit=1)[0]
SEC_MS_GEC_VERSION = f"1-{CHROMIUM_FULL_VERSION}"
Expand Down
21 changes: 19 additions & 2 deletions src/edge_tts/models.py → src/edge_tts/data_classes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""This module contains the TTSConfig dataclass, which represents the
internal TTS configuration for edge-tts's Communicate class."""
"""Data models for edge-tts."""

# pylint: disable=too-few-public-methods

import argparse
import re
from dataclasses import dataclass

Expand Down Expand Up @@ -69,3 +71,18 @@ def __post_init__(self) -> None:
self.validate_string_param("rate", self.rate, r"^[+-]\d+%$")
self.validate_string_param("volume", self.volume, r"^[+-]\d+%$")
self.validate_string_param("pitch", self.pitch, r"^[+-]\d+Hz$")


class UtilArgs(argparse.Namespace):
    """Typed namespace for the edge-tts command-line arguments.

    An instance is passed as ``namespace=`` to ``parser.parse_args()`` in
    util.py, so the parsed arguments carry the attribute types declared here
    instead of being an untyped ``argparse.Namespace``.

    NOTE(review): ``text``, ``file`` and ``proxy`` are registered in util.py
    without a ``default``, so argparse leaves them as ``None`` when the flag
    is omitted even though they are annotated ``str`` here. Consider
    ``Optional[str]`` for those fields -- confirm against the full parser.
    """

    # -t/--text: what TTS will say (mutually exclusive with --file).
    text: str
    # -f/--file: same as --text but read from a file.
    file: str
    # -v/--voice: voice for TTS; the parser defaults it to DEFAULT_VOICE.
    voice: str
    # -l flag; when truthy, amain() prints the voice list.
    list_voices: bool
    # Presumably the rate/volume/pitch adjustments forwarded to Communicate
    # (e.g. "+0%") -- the add_argument calls are not visible here; TODO confirm.
    rate: str
    volume: str
    pitch: str
    # Presumably the output path for the generated audio -- TODO confirm.
    write_media: str
    # Presumably the file that receives subtitle output instead of stderr
    # -- TODO confirm against the parser definition.
    write_subtitles: str
    # --proxy: proxy used for TTS and the voice list.
    proxy: str
17 changes: 13 additions & 4 deletions src/edge_tts/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,24 @@ class Voice(TypedDict):
VoiceTag: VoiceTag


class VoiceManagerVoice(Voice):
"""Voice data for VoiceManager."""
class VoicesManagerVoice(Voice):
"""Voice data for VoicesManager."""

Language: str


class VoiceManagerFind(TypedDict):
"""Voice data for VoiceManager.find()."""
class VoicesManagerFind(TypedDict):
"""Voice data for VoicesManager.find()."""

Gender: NotRequired[Literal["Female", "Male"]]
Locale: NotRequired[str]
Language: NotRequired[str]


class CommunicateState(TypedDict):
    """Communicate state data.

    Describes the keys of the ``state`` dict that ``Communicate.__init__``
    creates, replacing the previous ``Dict[str, Any]`` annotation so each
    key has a checked value type.
    """

    # Initialised to b"" by Communicate.__init__.
    # Presumably the chunk of text currently being synthesised -- TODO confirm.
    partial_text: bytes
    # Initialised to 0.
    # Presumably a running timing offset applied across chunks -- TODO confirm.
    offset_compensation: float
    # Initialised to 0.
    # Presumably the offset + duration of the last metadata event -- TODO confirm.
    last_duration_offset: float
    # Initialised to False.
    # Presumably flipped once stream() has been called -- TODO confirm.
    stream_was_called: bool
18 changes: 11 additions & 7 deletions src/edge_tts/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
import argparse
import asyncio
import sys
from typing import Any, Optional, TextIO
from typing import Optional, TextIO

from tabulate import tabulate

from . import Communicate, SubMaker, list_voices
from .constants import DEFAULT_VOICE
from .data_classes import UtilArgs


async def _print_voices(*, proxy: str) -> None:
async def _print_voices(*, proxy: Optional[str]) -> None:
"""Print all available voices."""
voices = await list_voices(proxy=proxy)
voices = sorted(voices, key=lambda voice: voice["ShortName"])
Expand All @@ -27,7 +29,7 @@ async def _print_voices(*, proxy: str) -> None:
print(tabulate(table, headers))


async def _run_tts(args: Any) -> None:
async def _run_tts(args: UtilArgs) -> None:
"""Run TTS after parsing arguments from command line."""

try:
Expand Down Expand Up @@ -84,15 +86,17 @@ async def _run_tts(args: Any) -> None:

async def amain() -> None:
"""Async main function"""
parser = argparse.ArgumentParser(description="Microsoft Edge TTS")
parser = argparse.ArgumentParser(
description="Text-to-speech using Microsoft Edge's online TTS service."
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("-t", "--text", help="what TTS will say")
group.add_argument("-f", "--file", help="same as --text but read from file")
parser.add_argument(
"-v",
"--voice",
help="voice for TTS. Default: en-US-AriaNeural",
default="en-US-AriaNeural",
help=f"voice for TTS. Default: {DEFAULT_VOICE}",
default=DEFAULT_VOICE,
)
group.add_argument(
"-l",
Expand All @@ -111,7 +115,7 @@ async def amain() -> None:
help="send subtitle output to provided file instead of stderr",
)
parser.add_argument("--proxy", help="use a proxy for TTS and voice list.")
args = parser.parse_args()
args = parser.parse_args(namespace=UtilArgs())

if args.list_voices:
await _print_voices(proxy=args.proxy)
Expand Down
10 changes: 5 additions & 5 deletions src/edge_tts/voices.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@

import json
import ssl
from typing import Any, List, Optional
from typing import List, Optional

import aiohttp
import certifi
from typing_extensions import Unpack

from .constants import SEC_MS_GEC_VERSION, VOICE_HEADERS, VOICE_LIST
from .drm import DRM
from .typing import Voice, VoiceManagerFind, VoiceManagerVoice
from .typing import Voice, VoicesManagerFind, VoicesManagerVoice


async def __list_voices(
Expand Down Expand Up @@ -91,12 +91,12 @@ class VoicesManager:
"""

def __init__(self) -> None:
self.voices: List[VoiceManagerVoice] = []
self.voices: List[VoicesManagerVoice] = []
self.called_create: bool = False

@classmethod
async def create(
cls: Any, custom_voices: Optional[List[Voice]] = None
cls, custom_voices: Optional[List[Voice]] = None
) -> "VoicesManager":
"""
Creates a VoicesManager object and populates it with all available voices.
Expand All @@ -109,7 +109,7 @@ async def create(
self.called_create = True
return self

def find(self, **kwargs: Unpack[VoiceManagerFind]) -> List[VoiceManagerVoice]:
def find(self, **kwargs: Unpack[VoicesManagerFind]) -> List[VoicesManagerVoice]:
"""
Finds all matching voices based on the provided attributes.
"""
Expand Down

0 comments on commit 0639576

Please sign in to comment.