Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inference endpoint #118

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
cccd7c1
Add all the code necessary for the inference endpoint.
andimarafioti Sep 11, 2024
2d55243
adding debugs
andimarafioti Sep 11, 2024
4b54c18
try with device
andimarafioti Sep 11, 2024
b3bca23
remove debug
andimarafioti Sep 11, 2024
44ab522
lower the volume
andimarafioti Sep 11, 2024
b402d96
push
andimarafioti Sep 11, 2024
684c748
first pass
andimarafioti Sep 19, 2024
adda8e2
remove device cpu
andimarafioti Sep 19, 2024
32d58a7
download unidic
andimarafioti Sep 19, 2024
caf50ee
debug
andimarafioti Sep 19, 2024
98c5d13
saving working state, just for text
andimarafioti Sep 20, 2024
b40558a
changes
andimarafioti Sep 20, 2024
d2d9182
updates client
andimarafioti Sep 24, 2024
9de0faa
rename handler
andimarafioti Sep 24, 2024
a9d3ead
add a bit of debug and sleep more correctly
andimarafioti Sep 24, 2024
4fd02d6
trying websockets
andimarafioti Sep 24, 2024
ea14a21
update done
andimarafioti Sep 24, 2024
3f28212
yield done when done
andimarafioti Sep 24, 2024
97aa3c2
improve client
andimarafioti Sep 24, 2024
b6a1d16
adapt melo
andimarafioti Sep 26, 2024
69057b7
remove unnecessary stuff
andimarafioti Sep 26, 2024
ffb2731
idea
andimarafioti Sep 27, 2024
dc4cd11
remove logging
andimarafioti Sep 27, 2024
66533a2
simplifications to the audio part
andimarafioti Sep 27, 2024
75b364b
dont let the thread die
andimarafioti Oct 21, 2024
0b5bc66
improvements
andimarafioti Oct 22, 2024
e9fb909
improvement
andimarafioti Oct 22, 2024
a40bf62
fix
andimarafioti Oct 22, 2024
e5b1c0f
few fixes
andimarafioti Oct 22, 2024
9e156c8
revert changes to listen and play
andimarafioti Oct 22, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions LLM/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,4 @@ def process(self, prompt):

# don't forget last sentence
yield (printable_text, language_code)
yield b"DONE"
14 changes: 10 additions & 4 deletions TTS/melo_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def warmup(self):
_ = self.model.tts_to_file("text", self.speaker_id, quiet=True)

def process(self, llm_sentence):
if llm_sentence == b"DONE":
self.should_listen.set()
yield b"DONE"
return

language_code = None

if isinstance(llm_sentence, tuple):
Expand Down Expand Up @@ -94,10 +99,11 @@ def process(self, llm_sentence):
)
except (AssertionError, RuntimeError) as e:
logger.error(f"Error in MeloTTSHandler: {e}")
audio_chunk = np.array([])
if len(audio_chunk) == 0:
self.should_listen.set()
return
audio_chunk = np.zeros([self.blocksize])
except Exception as e:
logger.error(f"Unknown error in MeloTTSHandler: {e}")
audio_chunk = np.zeros([self.blocksize])

audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
audio_chunk = (audio_chunk * 32768).astype(np.int16)
for i in range(0, len(audio_chunk), self.blocksize):
Expand Down
1 change: 1 addition & 0 deletions TTS/parler_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,4 @@ def process(self, llm_sentence):
)

self.should_listen.set()
yield b"END"
2 changes: 1 addition & 1 deletion arguments_classes/module_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class ModuleArguments:
mode: Optional[str] = field(
default="socket",
metadata={
"help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'socket'."
"help": "The mode to run the pipeline in. Either 'local', 'socket', or 'none'. Default is 'socket'."
},
)
local_mac_optimal_settings: bool = field(
Expand Down
185 changes: 185 additions & 0 deletions audio_streaming_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import threading
from queue import Queue
import sounddevice as sd
import numpy as np
import time
from dataclasses import dataclass, field
import websocket
import ssl


@dataclass
class AudioStreamingClientArguments:
    """Configuration for the streaming client: capture settings and endpoint access."""

    # Sample rate used for both the microphone input and playback output.
    sample_rate: int = field(
        metadata={"help": "Audio sample rate in Hz. Default is 16000."},
        default=16000,
    )
    # Number of samples per audio block handed to the input/output callbacks.
    chunk_size: int = field(
        metadata={"help": "The size of audio chunks in samples. Default is 512."},
        default=512,
    )
    # Inference endpoint base URL; "/ws" is appended and http(s) swapped to ws(s).
    api_url: str = field(
        metadata={"help": "The URL of the API endpoint."},
        default="https://yxfmjcvuzgi123sw.us-east-1.aws.endpoints.huggingface.cloud",
    )
    # Bearer token sent in the Authorization header.
    auth_token: str = field(
        metadata={"help": "Authentication token for the API."},
        default="your_auth_token",
    )


class AudioStreamingClient:
    """Bidirectional audio streaming client.

    Captures microphone audio with sounddevice, streams it to a remote
    endpoint over a WebSocket, and plays back the audio the server returns.
    Three worker threads run concurrently: the websocket-client event loop,
    a sender draining ``send_queue``, and a player consuming ``recv_queue``.
    """

    def __init__(self, args: AudioStreamingClientArguments):
        self.args = args
        # Set once at shutdown; all worker loops poll this to exit.
        self.stop_event = threading.Event()
        # Microphone chunks (np.int16 arrays) waiting to be sent upstream.
        self.send_queue = Queue()
        # Server audio chunks (np.int16 arrays) waiting for playback.
        self.recv_queue = Queue()
        self.session_id = None
        self.headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {self.args.auth_token}",
            "Content-Type": "application/json",
        }
        self.session_state = (
            "idle"  # Possible states: idle, sending, processing, waiting
        )
        # NOTE(review): on_message also assigns the state "listen", which is
        # not in the list above — confirm the intended state machine.
        # Set by on_open so start() can block until the socket is usable.
        self.ws_ready = threading.Event()

    def start(self):
        """Open the WebSocket, wait until it is connected, then stream audio.

        Blocks in start_audio_streaming() until the user presses Enter.
        """
        print("Starting audio streaming...")

        # http -> ws and https -> wss (str.replace hits the leading "http").
        ws_url = self.args.api_url.replace("http", "ws") + "/ws"

        self.ws = websocket.WebSocketApp(
            ws_url,
            header=[f"{key}: {value}" for key, value in self.headers.items()],
            on_open=self.on_open,
            on_message=self.on_message,
            on_error=self.on_error,
            on_close=self.on_close,
        )

        # Certificate verification is disabled (CERT_NONE) — fine for testing
        # against a known endpoint, not for production use.
        self.ws_thread = threading.Thread(
            target=self.ws.run_forever, kwargs={"sslopt": {"cert_reqs": ssl.CERT_NONE}}
        )
        self.ws_thread.start()

        # Wait for the WebSocket to be ready
        self.ws_ready.wait()
        self.start_audio_streaming()

    def start_audio_streaming(self):
        """Start mic capture plus the sender/player threads; block until Enter."""
        self.send_thread = threading.Thread(target=self.send_audio)
        self.play_thread = threading.Thread(target=self.play_audio)

        # The InputStream invokes audio_input_callback on a sounddevice
        # thread for every chunk_size block of int16 mono samples.
        with sd.InputStream(
            samplerate=self.args.sample_rate,
            channels=1,
            dtype="int16",
            callback=self.audio_input_callback,
            blocksize=self.args.chunk_size,
        ):
            self.send_thread.start()
            self.play_thread.start()
            input("Press Enter to stop streaming... \n")
            self.on_shutdown()

    def on_open(self, ws):
        """WebSocket open callback: unblock start()."""
        print("WebSocket connection opened.")
        self.ws_ready.set()  # Signal that the WebSocket is ready

    def on_message(self, ws, message):
        # message is bytes
        # b"DONE" is the server's end-of-response sentinel; anything else is
        # raw int16 PCM to be queued for playback.
        if message == b"DONE":
            print("listen")
            self.session_state = "listen"
        else:
            if self.session_state != "processing":
                print("processing")
                self.session_state = "processing"
            audio_np = np.frombuffer(message, dtype=np.int16)
            self.recv_queue.put(audio_np)

    def on_error(self, ws, error):
        """WebSocket error callback: log only; threads keep running."""
        print(f"WebSocket error: {error}")

    def on_close(self, ws, close_status_code, close_msg):
        """WebSocket close callback: log only."""
        print("WebSocket connection closed.")

    def on_shutdown(self):
        """Stop all worker threads and close the socket, in dependency order."""
        self.stop_event.set()
        self.send_thread.join()
        self.play_thread.join()
        self.ws.close()
        self.ws_thread.join()
        print("Service shutdown.")

    def send_audio(self):
        """Sender loop: forward mic chunks upstream until stop_event is set.

        While the server is "processing" (speaking), mic audio is consumed
        from the queue but an empty binary frame is sent instead, keeping
        the connection active without feeding the pipeline.
        """
        while not self.stop_event.is_set():
            if not self.send_queue.empty():
                chunk = self.send_queue.get()
                if self.session_state != "processing":
                    self.ws.send(chunk.tobytes(), opcode=websocket.ABNF.OPCODE_BINARY)
                else:
                    self.ws.send([], opcode=websocket.ABNF.OPCODE_BINARY)  # handshake
            # Small sleep to avoid a busy-wait when the queue is empty.
            time.sleep(0.01)

    def audio_input_callback(self, indata, frames, time, status):
        """sounddevice input callback: queue a copy of each captured block."""
        self.send_queue.put(indata.copy())

    def audio_out_callback(self, outdata, frames, time, status):
        """sounddevice output callback: fill outdata from recv_queue or with silence."""
        if not self.recv_queue.empty():
            chunk = self.recv_queue.get()

            # Ensure chunk is int16 and clip to valid range
            chunk_int16 = np.clip(chunk, -32768, 32767).astype(np.int16)

            if len(chunk_int16) < len(outdata):
                # Partial chunk: play what we have, zero-pad the rest.
                outdata[: len(chunk_int16), 0] = chunk_int16
                outdata[len(chunk_int16) :] = 0
            else:
                # Oversized chunk: truncate to the block. NOTE(review): any
                # samples beyond len(outdata) are dropped — confirm server
                # chunks never exceed the client blocksize.
                outdata[:, 0] = chunk_int16[: len(outdata)]
        else:
            outdata[:] = 0

    def play_audio(self):
        """Player loop: keep the OutputStream open until stop_event is set."""
        with sd.OutputStream(
            samplerate=self.args.sample_rate,
            channels=1,
            dtype="int16",
            callback=self.audio_out_callback,
            blocksize=self.args.chunk_size,
        ):
            while not self.stop_event.is_set():
                time.sleep(0.1)


if __name__ == "__main__":
    import argparse

    # CLI mirrors AudioStreamingClientArguments 1:1 — argparse dest names
    # must match the dataclass field names for the **vars(args) expansion.
    parser = argparse.ArgumentParser(description="Audio Streaming Client")
    parser.add_argument(
        "--sample_rate",
        type=int,
        default=16000,
        help="Audio sample rate in Hz. Default is 16000.",
    )
    parser.add_argument(
        "--chunk_size",
        type=int,
        # Fixed: was 1024, which disagreed with the dataclass default (512)
        # and the server-side VAD chunk size (512 samples).
        default=512,
        help="The size of audio chunks in samples. Default is 512.",
    )
    parser.add_argument(
        "--api_url", type=str, required=True, help="The URL of the API endpoint."
    )
    parser.add_argument(
        "--auth_token",
        type=str,
        required=True,
        help="Authentication token for the API.",
    )

    args = parser.parse_args()
    client_args = AudioStreamingClientArguments(**vars(args))
    client = AudioStreamingClient(client_args)
    client.start()
2 changes: 1 addition & 1 deletion listen_and_play.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,4 @@ def receive_full_chunk(conn, chunk_size):
if __name__ == "__main__":
parser = HfArgumentParser((ListenAndPlayArguments,))
(listen_and_play_kwargs,) = parser.parse_args_into_dataclasses()
listen_and_play(**vars(listen_and_play_kwargs))
listen_and_play(**vars(listen_and_play_kwargs))
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ funasr>=1.1.6
faster-whisper>=1.0.3
modelscope>=1.17.1
deepfilternet>=0.5.6
openai>=1.40.1
openai>=1.40.1
websocket-client>=1.8.0
3 changes: 2 additions & 1 deletion requirements_mac.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ funasr>=1.1.6
faster-whisper>=1.0.3
modelscope>=1.17.1
deepfilternet>=0.5.6
openai>=1.40.1
openai>=1.40.1
websocket-client>=1.8.0
105 changes: 105 additions & 0 deletions s2s_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from typing import Dict, Any, List, Generator
import torch
import os
import logging
from s2s_pipeline import main, prepare_all_args, get_default_arguments, setup_logger, initialize_queues_and_events, build_pipeline
import numpy as np
from queue import Queue, Empty
import threading
import base64
import uuid
import torch

class EndpointHandler:
    """Inference-endpoint adapter around the speech-to-speech pipeline.

    Builds the full pipeline (VAD -> STT -> LLM -> TTS) once at start-up and
    exposes process_streaming_data(), which pushes raw int16 PCM bytes into
    the pipeline's receive queue and drains at most one output chunk per call.
    """

    def __init__(self, path=""):
        # Model/device selection comes from the environment so the same image
        # can be deployed with different models without code changes.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        lm_model_name = os.getenv('LM_MODEL_NAME', 'meta-llama/Meta-Llama-3.1-8B-Instruct')
        chat_size = int(os.getenv('CHAT_SIZE', 10))

        # mode='none' disables the socket sender/receiver: audio moves through
        # the queues below instead of the pipeline's own network layer.
        (
            self.module_kwargs,
            self.socket_receiver_kwargs,
            self.socket_sender_kwargs,
            self.vad_handler_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.faster_whisper_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.open_api_language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
            self.facebook_mm_stts_handler_kwargs,
        ) = get_default_arguments(mode='none', log_level='DEBUG', lm_model_name=lm_model_name,
                                  tts="melo", device=device, chat_size=chat_size)
        setup_logger(self.module_kwargs.log_level)

        prepare_all_args(
            self.module_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.faster_whisper_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.open_api_language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
            self.facebook_mm_stts_handler_kwargs,
        )

        self.queues_and_events = initialize_queues_and_events()

        self.pipeline_manager = build_pipeline(
            self.module_kwargs,
            self.socket_receiver_kwargs,
            self.socket_sender_kwargs,
            self.vad_handler_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.faster_whisper_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.open_api_language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
            self.facebook_mm_stts_handler_kwargs,
            self.queues_and_events,
        )

        self.vad_chunk_size = 512  # Set the chunk size required by the VAD model
        self.sample_rate = 16000  # Set the expected sample rate

    def process_streaming_data(self, data: bytes) -> bytes:
        """Feed one websocket frame of int16 PCM into the pipeline.

        Returns one pending output chunk as bytes, or None when the pipeline
        has nothing ready yet (the caller is expected to poll again).
        """
        audio_array = np.frombuffer(data, dtype=np.int16)

        # The VAD consumes fixed-size chunks; re-slice whatever frame size
        # the client sent, zero-padding the trailing remainder.
        chunks = [audio_array[i:i + self.vad_chunk_size]
                  for i in range(0, len(audio_array), self.vad_chunk_size)]
        for chunk in chunks:
            if len(chunk) < self.vad_chunk_size:
                chunk = np.pad(chunk, (0, self.vad_chunk_size - len(chunk)), 'constant')
            self.queues_and_events['recv_audio_chunks_queue'].put(chunk.tobytes())

        # Collect the output, if any
        try:
            output = self.queues_and_events['send_audio_chunks_queue'].get_nowait()  # improvement idea, group all available output chunks
            if isinstance(output, np.ndarray):
                return output.tobytes()
            else:
                # Pass through sentinels (e.g. b"DONE") unchanged.
                return output
        except Empty:
            return None

    def cleanup(self):
        """Stop the pipeline and unblock/join any output-collector thread."""
        # Stop the pipeline
        self.pipeline_manager.stop()

        # Sentinel so any consumer blocked on the output queue wakes up.
        # NOTE(review): the streaming client checks for b"DONE" — confirm
        # b"END" is the sentinel consumers actually expect.
        self.queues_and_events['send_audio_chunks_queue'].put(b"END")

        # Bug fix: `output_collector_thread` is never created in __init__, so
        # the previous unconditional join always raised AttributeError here.
        # Join it only if some subclass/deployment actually set it up.
        collector = getattr(self, 'output_collector_thread', None)
        if collector is not None:
            collector.join()
Loading