Merge pull request #5 from sensein/flow
Adding speech to visemes as a child of BaseHandler
fabiocat93 authored Oct 8, 2024
2 parents 7176a1b + c7b85e1 commit 1bc8186
Showing 14 changed files with 401 additions and 231 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -1,4 +1,6 @@
__pycache__
tmp
cache
mlx_models/
asset/
config/
11 changes: 11 additions & 0 deletions README.md
@@ -28,6 +28,7 @@ This repository implements a speech-to-speech cascaded pipeline consisting of the following parts:
2. **Speech to Text (STT)**
3. **Language Model (LM)**
4. **Text to Speech (TTS)**
5. **Speech to Visemes (STV)**

### Modularity
The pipeline provides a fully open and modular approach, with a focus on leveraging models available through the Transformers library on the Hugging Face hub. The code is designed for easy modification, and we already support device-specific and external library implementations:
@@ -50,6 +51,9 @@
- [MeloTTS](https://github.com/myshell-ai/MeloTTS)
- [ChatTTS](https://github.com/2noise/ChatTTS?tab=readme-ov-file)

**STV**
- [Wav2Vec2Phoneme](https://huggingface.co/docs/transformers/en/model_doc/wav2vec2_phoneme) + [Phoneme to viseme mapping](https://learn.microsoft.com/en-us/azure/ai-services/speech-service/how-to-speech-synthesis-viseme?tabs=visemeid&pivots=programming-language-python#map-phonemes-to-visemes)
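
The core idea, as a minimal sketch (not the handler code added below; `speech.wav` is a placeholder input, while the model and map file names are the ones used in this PR):

```python
import json

from transformers import pipeline

# Phoneme-level ASR: a wav2vec2 model fine-tuned to emit phonemes,
# with character-level timestamps for each one.
asr = pipeline("automatic-speech-recognition", model="bookbot/wav2vec2-ljspeech-gruut")
result = asr("speech.wav", return_timestamps="char")

# Each phoneme maps to zero or more Azure-style viseme IDs.
with open("STV/phoneme_viseme_map.json") as f:
    phoneme_viseme_map = json.load(f)

visemes = [
    {"viseme": viseme_id, "timestamp": chunk["timestamp"]}
    for chunk in result["chunks"]
    for viseme_id in phoneme_viseme_map.get(chunk["text"], [])
]
```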

## Setup

Clone the repository:
@@ -216,6 +220,13 @@ For example:
```
--lm_model_name google/gemma-2b-it
```
### STV parameters
See the [Wav2Vec2STVHandlerArguments](arguments_classes/w2v_stv_arguments.py) class. Notably:
- `stv_model_name` defaults to `bookbot/wav2vec2-ljspeech-gruut`, chosen because it is both accurate and fast enough
- `stv_skip`: set it to `True` if you don't need visemes
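
For example (mirroring the `--lm_model_name` example above):
```
--stv_model_name bookbot/wav2vec2-ljspeech-gruut
```
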
### Generation parameters
Other parameters of the model's `generate` method can be set by combining the pipeline part's prefix with `_gen_`, e.g., `--stt_gen_max_new_tokens 128`. If a parameter is not already exposed, it can be added to the pipeline part's arguments class.
1 change: 0 additions & 1 deletion STT/paraformer_handler.py
@@ -28,7 +28,6 @@ def setup(
        device="cuda",
        gen_kwargs={},
    ):
        print(model_name)
        if len(model_name.split("/")) > 1:
            model_name = model_name.split("/")[-1]
        self.device = device
File renamed without changes.
253 changes: 253 additions & 0 deletions STV/w2v_stv_handler.py
@@ -0,0 +1,253 @@
import json
import logging
import time
from typing import Any, Dict, Generator, List

import numpy as np
from rich.console import Console
from transformers import pipeline

from baseHandler import BaseHandler

logger = logging.getLogger(__name__)
console = Console()


class Wav2Vec2STVHandler(BaseHandler):
    """
    Handles speech-to-viseme generation using a Wav2Vec2 model for automatic
    speech recognition (ASR) and a phoneme-to-viseme mapping.

    Attributes:
        MIN_AUDIO_LENGTH (float): Minimum length of audio (in seconds) required
            for phoneme extraction.
    """

    MIN_AUDIO_LENGTH = 0.5  # Minimum audio length in seconds for phoneme extraction

    def setup(
        self,
        should_listen: bool,
        model_name: str = "bookbot/wav2vec2-ljspeech-gruut",
        blocksize: int = 512,
        device: str = "cuda",
        skip: bool = False,
        gen_kwargs: Dict[str, Any] = {},  # Not used
    ) -> None:
        """
        Initializes the handler by loading the ASR model and the phoneme-to-viseme map.

        Args:
            should_listen (bool): Flag indicating whether the speech-to-speech
                pipeline should start listening to the user or not.
            model_name (str): Name of the ASR model to use.
                Defaults to "bookbot/wav2vec2-ljspeech-gruut".
            blocksize (int): Size of each audio block when processing audio.
                Defaults to 512.
            device (str): Device to run the model on ("cuda", "mps", or "cpu").
                Defaults to "cuda".
            skip (bool): If True, the speech-to-viseme process is skipped.
                Defaults to False.
            gen_kwargs (dict): Additional parameters for speech generation.

        Returns:
            None
        """
        self.device = device
        self.gen_kwargs = gen_kwargs
        self.blocksize = blocksize
        self.should_listen = should_listen
        self.skip = skip

        # Load the phoneme-to-viseme map from the JSON file, inspired by
        # https://learn.microsoft.com/en-us/azure/ai-services/speech-service/speech-ssml-phonetic-sets
        phoneme_viseme_map_file = "STV/phoneme_viseme_map.json"
        with open(phoneme_viseme_map_file, "r") as f:
            self.phoneme_viseme_map = json.load(f)

        # Initialize the ASR pipeline using the specified model and device
        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=model_name,
            device=device,
            torch_dtype="auto",
        )
        self.expected_sampling_rate = self.asr_pipeline.feature_extractor.sampling_rate

        # Initialize an empty buffer for accumulating audio batch data
        self.audio_batch = {
            "waveform": np.array([]),
            "sampling_rate": self.expected_sampling_rate,
        }
        self.text_batch = None
        self.should_listen_flag = False

        self.warmup()  # Perform model warmup

    def warmup(self) -> None:
        """Warms up the model with a dummy input to prepare it for inference.

        Returns:
            None
        """
        logger.info(f"Warming up {self.__class__.__name__}")
        start_time = time.time()

        # Create a dummy input for warmup inference
        dummy_input = np.random.randn(self.blocksize).astype(np.int16)
        _ = self.speech_to_visemes(dummy_input)

        warmup_time = time.time() - start_time
        logger.info(
            f"{self.__class__.__name__}: warmed up in {warmup_time:.4f} seconds!"
        )

    def speech_to_visemes(self, audio: Any) -> List[Dict[str, Any]]:
        """
        Converts speech audio to visemes by performing automatic speech
        recognition (ASR) and mapping phonemes to visemes.

        Args:
            audio (Any): The input audio data.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries containing mapped visemes
                and their corresponding timestamps.

        Note:
            Heuristically, the input audio should be at least 0.5 seconds long
            for proper phoneme extraction.
        """

        def _map_phonemes_to_visemes(
            data: Dict[str, Any],
        ) -> List[Dict[str, Any]]:
            """
            Maps extracted phonemes to their corresponding visemes based on a
            predefined map.

            Args:
                data (Dict[str, Any]): Dictionary containing phoneme data, where
                    data["chunks"] holds a list of phonemes and their timestamps.

            Returns:
                List[Dict[str, Any]]: A list of dictionaries with viseme IDs and
                    their corresponding timestamps.
            """
            viseme_list = []
            chunks = data.get("chunks", [])

            # Map each phoneme to its corresponding visemes
            for chunk in chunks:
                phoneme = chunk.get("text", None)
                timestamp = chunk.get("timestamp", None)
                visemes = self.phoneme_viseme_map.get(phoneme, [])

                for viseme in visemes:
                    viseme_list.append({"viseme": viseme, "timestamp": timestamp})

            return viseme_list

        # Perform ASR to extract phoneme data, including timestamps
        try:
            asr_result = self.asr_pipeline(audio, return_timestamps="char")
        except Exception as e:
            logger.error(f"ASR error: {e}")
            return []

        # Map the phonemes obtained from ASR to visemes
        return _map_phonemes_to_visemes(asr_result)

    def process(self, data: Dict[str, Any]) -> Generator[Dict[str, Any], None, None]:
        """
        Processes incoming audio to generate visemes and yields blocks of audio
        data along with the corresponding viseme data.

        Args:
            data (Dict[str, Any]): Dictionary containing audio, text, and
                potentially additional information.

        Yields:
            Dict: A dictionary containing the audio waveform and, optionally,
                viseme data, text, and additional information.
        """
        if "sentence_end" in data and data["sentence_end"]:
            self.should_listen_flag = True

        if self.skip:  # Skip viseme extraction if the flag is set
            yield {
                "audio": {
                    "waveform": data["audio"]["waveform"],
                    "sampling_rate": data["audio"]["sampling_rate"],
                },
                "text": data["text"] if "text" in data else None,
            }
        else:
            # Check if text data is present and save it for later
            if "text" in data and data["text"] is not None:
                self.text_batch = data["text"]

            # Concatenate new audio data into the buffer if available and valid
            if "audio" in data and data["audio"] is not None:
                audio_data = data["audio"]
                # Check that the sampling rate matches the expected one
                if audio_data.get("sampling_rate", None) != self.expected_sampling_rate:
                    logger.error(
                        f"Expected sampling rate {self.expected_sampling_rate}, "
                        f"but got {audio_data['sampling_rate']}."
                    )
                    return
                # Append the waveform to the audio buffer
                self.audio_batch["waveform"] = np.concatenate(
                    (self.audio_batch["waveform"], audio_data["waveform"]), axis=0
                )

            # Ensure the total audio length is sufficient for phoneme extraction
            if (
                len(self.audio_batch["waveform"]) / self.audio_batch["sampling_rate"]
                < self.MIN_AUDIO_LENGTH
            ):
                return
            else:
                logger.debug("Starting viseme inference...")

                # Perform viseme inference on the accumulated audio batch
                viseme_data = self.speech_to_visemes(self.audio_batch["waveform"])
                logger.debug("Viseme inference completed.")

                # Print the visemes and timestamps to the console
                for viseme in viseme_data:
                    console.print(
                        f"[blue]ASSISTANT_MOUTH_SHAPE: {viseme['viseme']} -- {viseme['timestamp']}"
                    )

                # Process the audio in chunks of the defined blocksize
                self.audio_batch["waveform"] = self.audio_batch["waveform"].astype(
                    np.int16
                )
                for i in range(0, len(self.audio_batch["waveform"]), self.blocksize):
                    chunk_waveform = self.audio_batch["waveform"][
                        i : i + self.blocksize
                    ]
                    padded_waveform = np.pad(
                        chunk_waveform, (0, self.blocksize - len(chunk_waveform))
                    )

                    chunk_data = {
                        "audio": {
                            "waveform": padded_waveform,
                            # Use the same key as the skip branch ("sampling_rate")
                            # so downstream consumers see a consistent schema
                            "sampling_rate": self.audio_batch["sampling_rate"],
                        }
                    }

                    # Add text and viseme data only to the first chunk
                    if i == 0:
                        if self.text_batch:
                            chunk_data["text"] = self.text_batch
                        if viseme_data and len(viseme_data) > 0:
                            chunk_data["visemes"] = viseme_data
                    yield chunk_data

                # Reset the audio and text buffers after processing
                self.audio_batch = {
                    "waveform": np.array([]),
                    "sampling_rate": self.expected_sampling_rate,
                }
                self.text_batch = None  # match the initial value set in setup()

        if self.should_listen_flag:
            self.should_listen.set()
            self.should_listen_flag = False
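
To poke at the handler in isolation, a minimal sketch (the constructor wiring assumes `BaseHandler` takes `stop_event`, `queue_in`, `queue_out`, and `setup_kwargs`, as in the rest of this pipeline; the values below are hypothetical):

```python
import threading
from queue import Queue

import numpy as np

from STV.w2v_stv_handler import Wav2Vec2STVHandler

# Stand-alone wiring; in the real pipeline these come from the orchestrator.
stop_event = threading.Event()
should_listen = threading.Event()
handler = Wav2Vec2STVHandler(
    stop_event,
    queue_in=Queue(),
    queue_out=Queue(),
    setup_kwargs={"should_listen": should_listen, "device": "cpu"},
)

# One second of int16 silence at wav2vec2's expected 16 kHz sampling rate.
data = {
    "audio": {
        "waveform": np.zeros(16000, dtype=np.int16),
        "sampling_rate": 16000,
    },
    "text": "hello",
    "sentence_end": True,
}

# process() buffers audio until MIN_AUDIO_LENGTH is reached, then yields
# blocksize-sized audio chunks; the first chunk also carries text and visemes.
for chunk in handler.process(data):
    print(chunk.get("visemes"), chunk["audio"]["waveform"].shape)
```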