Speech to visemes as a child of BaseHandler #5

Merged · 3 commits · Oct 8, 2024
4 changes: 3 additions & 1 deletion .gitignore
@@ -1,4 +1,6 @@
__pycache__
tmp
cache
mlx_models/
mlx_models/
asset/
config/
11 changes: 11 additions & 0 deletions README.md
@@ -28,6 +28,7 @@ This repository implements a speech-to-speech cascaded pipeline consisting of th
2. **Speech to Text (STT)**
3. **Language Model (LM)**
4. **Text to Speech (TTS)**
5. **Speech to Visemes (STV)**

### Modularity
The pipeline provides a fully open and modular approach, with a focus on leveraging models available through the Transformers library on the Hugging Face hub. The code is designed for easy modification, and we already support device-specific and external library implementations:
@@ -50,6 +51,9 @@ The pipeline provides a fully open and modular approach, with a focus on leverag
- [MeloTTS](https://github.com/myshell-ai/MeloTTS)
- [ChatTTS](https://github.com/2noise/ChatTTS?tab=readme-ov-file)

**STV**
- [Wav2Vec2Phoneme](https://huggingface.co/docs/transformers/en/model_doc/wav2vec2_phoneme) + [Phoneme to viseme mapping](https://learn.microsoft.com/en-us/azure/ai-services/speech-service/how-to-speech-synthesis-viseme?tabs=visemeid&pivots=programming-language-python#map-phonemes-to-visemes)

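A minimal sketch of the STV step listed above, assuming a local audio file and the default checkpoint (the production implementation lives in `STV/w2v_stv_handler.py`; file names here are illustrative):

```python
import json
from transformers import pipeline

# Phoneme-level ASR (Wav2Vec2Phoneme checkpoint) with per-character timestamps
asr = pipeline(
    "automatic-speech-recognition",
    model="bookbot/wav2vec2-ljspeech-gruut",
)
result = asr("sample.wav", return_timestamps="char")  # "sample.wav" is illustrative

# Map each recognized phoneme to viseme IDs using the repo's JSON lookup table
with open("STV/phoneme_viseme_map.json") as f:
    phoneme_viseme_map = json.load(f)

visemes = [
    {"viseme": viseme, "timestamp": chunk["timestamp"]}
    for chunk in result["chunks"]
    for viseme in phoneme_viseme_map.get(chunk["text"], [])
]
print(visemes)
```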
## Setup

Clone the repository:
@@ -216,6 +220,13 @@ For example:
--lm_model_name google/gemma-2b-it
```


### STV parameters
See the [Wav2Vec2STVHandlerArguments](arguments_classes/w2v_stv_arguments.py) class. Notably:
- `stv_model_name` defaults to `bookbot/wav2vec2-ljspeech-gruut`, chosen because it is accurate and fast enough
- `stv_skip`: set it to `True` if you do not need visemes

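For example (the entry-point script name below follows the upstream speech-to-speech repository and is assumed here; adapt it to your setup):

```
python s2s_pipeline.py --stv_model_name bookbot/wav2vec2-ljspeech-gruut
```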

### Generation parameters

Other generation parameters of the model's generate method can be set using the part's prefix + `_gen_`, e.g., `--stt_gen_max_new_tokens 128`. These parameters can be added to the pipeline part's arguments class if not already exposed.
1 change: 0 additions & 1 deletion STT/paraformer_handler.py
@@ -28,7 +28,6 @@ def setup(
device="cuda",
gen_kwargs={},
):
print(model_name)
if len(model_name.split("/")) > 1:
model_name = model_name.split("/")[-1]
self.device = device
File renamed without changes.
253 changes: 253 additions & 0 deletions STV/w2v_stv_handler.py
@@ -0,0 +1,253 @@
import json
import logging
import time
from typing import Any, Dict, Generator, List

import numpy as np
from rich.console import Console
from transformers import pipeline

from baseHandler import BaseHandler

logger = logging.getLogger(__name__)
console = Console()


class Wav2Vec2STVHandler(BaseHandler):
"""
    Handles speech-to-viseme generation using a Wav2Vec2 model for automatic
    speech recognition (ASR) and a phoneme-to-viseme mapping.

Attributes:
MIN_AUDIO_LENGTH (float): Minimum length of audio (in seconds) required
for phoneme extraction.
"""

MIN_AUDIO_LENGTH = 0.5 # Minimum audio length in seconds for phoneme extraction

def setup(
self,
should_listen: bool,
model_name: str = "bookbot/wav2vec2-ljspeech-gruut",
blocksize: int = 512,
device: str = "cuda",
skip: bool = False,
gen_kwargs: Dict[str, Any] = {}, # Not used
) -> None:
"""
Initializes the handler by loading the ASR model and phoneme-to-viseme map.

Args:
should_listen (bool): Flag indicating whether the speech-to-speech pipeline should start
listening to the user or not.
model_name (str): Name of the ASR model to use.
Defaults to "bookbot/wav2vec2-ljspeech-gruut".
blocksize (int): Size of each audio block when processing audio.
Defaults to 512.
device (str): Device to run the model on ("cuda", "mps", or "cpu").
Defaults to "cuda".
skip (bool): If True, the speech-to-viseme process is skipped.
Defaults to False.
gen_kwargs (dict): Additional parameters for speech generation.

Returns:
None
"""
self.device = device
self.gen_kwargs = gen_kwargs
self.blocksize = blocksize
self.should_listen = should_listen
self.skip = skip

# Load phoneme-to-viseme map from the JSON file
# inspired by https://learn.microsoft.com/en-us/azure/ai-services/speech-service/speech-ssml-phonetic-sets
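        # The map is expected to have the shape {phoneme: [viseme_id, ...]}, e.g.
        # (illustrative entry) {"b": [21]}; values are lists because a single
        # phoneme may correspond to more than one viseme.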
phoneme_viseme_map_file = "STV/phoneme_viseme_map.json"
with open(phoneme_viseme_map_file, "r") as f:
self.phoneme_viseme_map = json.load(f)

# Initialize the ASR pipeline using the specified model and device
self.asr_pipeline = pipeline(
"automatic-speech-recognition",
model=model_name,
device=device,
torch_dtype="auto",
)
self.expected_sampling_rate = self.asr_pipeline.feature_extractor.sampling_rate

# Initialize an empty dictionary to store audio batch data
self.audio_batch = {
"waveform": np.array([]),
"sampling_rate": self.expected_sampling_rate,
}
self.text_batch = None
self.should_listen_flag = False

self.warmup() # Perform model warmup

def warmup(self) -> None:
"""Warms up the model with dummy input to prepare it for inference.

Returns:
None
"""
logger.info(f"Warming up {self.__class__.__name__}")
start_time = time.time()

# Create dummy input for warmup inference
dummy_input = np.random.randn(self.blocksize).astype(np.int16)
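        # Note: casting standard-normal samples to int16 truncates most values to zero,
        # so the warmup effectively runs the full ASR path on near-silent audio.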
_ = self.speech_to_visemes(dummy_input)

warmup_time = time.time() - start_time
logger.info(
f"{self.__class__.__name__}: warmed up in {warmup_time:.4f} seconds!"
)

def speech_to_visemes(self, audio: Any) -> List[Dict[str, Any]]:
"""
Converts speech audio to visemes by performing Automatic Speech Recognition (ASR)
and mapping phonemes to visemes.

Args:
audio (Any): The input audio data.

Returns:
List[Dict[str, Any]]: A list of dictionaries containing mapped visemes
and their corresponding timestamps.

Note:
Heuristically, the input audio should be at least 0.5 seconds long for proper phoneme extraction.
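
        Example of the returned structure (illustrative values):
            [{"viseme": 21, "timestamp": (0.0, 0.04)}, {"viseme": 1, "timestamp": (0.04, 0.08)}]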
"""

def _map_phonemes_to_visemes(
data: Dict[str, Any],
) -> List[Dict[str, Any]]:
"""
Maps extracted phonemes to their corresponding visemes based on a predefined map.

Args:
data (Dict[str, Any]): Dictionary containing phoneme data where data['chunks']
holds a list of phonemes and their timestamps.

Returns:
List[Dict[str, Any]]: A list of dictionaries with viseme IDs and their corresponding timestamps.
"""
viseme_list = []
chunks = data.get("chunks", [])

# Map each phoneme to corresponding visemes
for chunk in chunks:
phoneme = chunk.get("text", None)
timestamp = chunk.get("timestamp", None)
visemes = self.phoneme_viseme_map.get(phoneme, [])

for viseme in visemes:
viseme_list.append({"viseme": viseme, "timestamp": timestamp})

return viseme_list

# Perform ASR to extract phoneme data, including timestamps
try:
asr_result = self.asr_pipeline(audio, return_timestamps="char")
except Exception as e:
logger.error(f"ASR error: {e}")
return []
# Map the phonemes obtained from ASR to visemes
return _map_phonemes_to_visemes(asr_result)

def process(self, data: Dict[str, Any]) -> Generator[Dict[str, Any], None, None]:
"""
        Processes incoming audio data to generate visemes and yields fixed-size
        blocks of audio along with the corresponding viseme data.

Args:
data (Dict[str, Any]): Dictionary containing audio, text, and potentially additional information.

Yields:
            Dict: A dictionary containing the audio waveform and, optionally, viseme data, text, and additional information.
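
            Example of the first yielded chunk (illustrative values):
                {
                    "audio": {"waveform": <np.ndarray of blocksize samples>, "sample_rate": 16000},
                    "text": "Hello there",
                    "visemes": [{"viseme": 21, "timestamp": (0.0, 0.04)}],
                }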
"""

if "sentence_end" in data and data["sentence_end"]:
self.should_listen_flag = True
if self.skip: # Skip viseme extraction if the flag is set
yield {
"audio": {
"waveform": data["audio"]["waveform"],
"sampling_rate": data["audio"]["sampling_rate"],
},
"text": data["text"] if "text" in data else None,
}
else:
# Check if text data is present and save it for later
if "text" in data and data["text"] is not None:
self.text_batch = data["text"]
# Concatenate new audio data into the buffer if available and valid
if "audio" in data and data["audio"] is not None:
audio_data = data["audio"]
# Check if the sampling rate is valid and matches the expected one
if audio_data.get("sampling_rate", None) != self.expected_sampling_rate:
logger.error(
f"Expected sampling rate {self.expected_sampling_rate}, "
f"but got {audio_data['sampling_rate']}."
)
return
# Append the waveform to the audio buffer
self.audio_batch["waveform"] = np.concatenate(
(self.audio_batch["waveform"], audio_data["waveform"]), axis=0
)

# Ensure the total audio length is sufficient for phoneme extraction
if (
len(self.audio_batch["waveform"]) / self.audio_batch["sampling_rate"]
< self.MIN_AUDIO_LENGTH
):
return
else:
logger.debug("Starting viseme inference...")

# Perform viseme inference using the accumulated audio batch
viseme_data = self.speech_to_visemes(self.audio_batch["waveform"])
logger.debug("Viseme inference completed.")

# Print the visemes and timestamps to the console
for viseme in viseme_data:
console.print(
f"[blue]ASSISTANT_MOUTH_SHAPE: {viseme['viseme']} -- {viseme['timestamp']}"
)

# Process the audio in chunks of the defined blocksize
self.audio_batch["waveform"] = self.audio_batch["waveform"].astype(
np.int16
)
for i in range(0, len(self.audio_batch["waveform"]), self.blocksize):
chunk_waveform = self.audio_batch["waveform"][
i : i + self.blocksize
]
padded_waveform = np.pad(
chunk_waveform, (0, self.blocksize - len(chunk_waveform))
)

chunk_data = {
"audio": {
"waveform": padded_waveform,
"sample_rate": self.audio_batch["sampling_rate"],
}
}

# Add text and viseme data only in the first chunk
if i == 0:
if self.text_batch:
chunk_data["text"] = self.text_batch
if viseme_data and len(viseme_data) > 0:
chunk_data["visemes"] = viseme_data
yield chunk_data

# Reset the audio and text buffer after processing
self.audio_batch = {
"waveform": np.array([]),
"sampling_rate": self.expected_sampling_rate,
}
self.text_batch = ""

if self.should_listen_flag:
self.should_listen.set()
self.should_listen_flag = False