Skip to content

Commit 67f594d

Browse files
Merge pull request #164 from freddyaboulton/gradio-demo
Add Gradio Demo
2 parents b434b63 + 666b459 commit 67f594d

File tree

3 files changed

+175
-0
lines changed

3 files changed

+175
-0
lines changed

README.md

+11
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ nor does it try to compensate for a growing lag by skipping frames.
158158
Alternatively you can run `python -m moshi_mlx.local_web` to use
159159
the web UI, the connection is via http and will be at [localhost:8998](http://localhost:8998).
160160

161+
161162
## Rust
162163

163164
In order to run the Rust inference server, use the following command from within
@@ -206,6 +207,16 @@ cargo run --bin moshi-cli -r -- tui --host localhost
206207
python -m moshi.client
207208
```
208209

210+
### Gradio Demo
211+
212+
You can launch a Gradio demo locally with the following command:
213+
214+
```bash
215+
python -m moshi.client_gradio --url <moshi-server-url>
216+
```
217+
218+
Prior to running the Gradio demo, please install `gradio-webrtc>=0.0.18`.
219+
209220
### Docker Compose (CUDA only)
210221

211222
```bash

moshi/moshi/client_gradio.py

+163
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
"""Gradio-based WebRTC client for a running moshi server.

Requires the optional dependency ``gradio-webrtc>=0.0.18``.
"""

import argparse
from typing import Generator, Literal, cast

import numpy as np
import sphn
from numpy.typing import NDArray

try:
    import gradio as gr
    import websockets.sync.client
    from gradio_webrtc import AdditionalOutputs, StreamHandler, WebRTC
except ImportError:
    raise ImportError("Please install gradio-webrtc>=0.0.18 to run this script.")

# Left as None for local use. For deployment on cloud platforms (Heroku,
# Spaces, etc.) set this per the gradio-webrtc deployment guide:
# https://freddyaboulton.github.io/gradio-webrtc/deployment/
rtc_configuration = None
21+
class MoshiHandler(StreamHandler):
    """Stream handler bridging a browser WebRTC audio stream to a moshi server.

    Incoming PCM frames from the client are Opus-encoded and forwarded over a
    websocket to the server's ``/api/chat`` endpoint; audio (kind 1) and text
    (kind 2) messages coming back are decoded and yielded to the WebRTC
    component.
    """

    def __init__(
        self,
        url: str,
        expected_layout: Literal["mono", "stereo"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
    ) -> None:
        """
        Args:
            url: Base URL of the moshi server (http/https/ws/wss scheme).
            expected_layout: Channel layout expected from the client.
            output_sample_rate: Sample rate (Hz) of audio sent to the client.
            output_frame_size: Samples per frame emitted to the client.
        """
        self.url = url
        # Normalize the scheme to a websocket one: http -> ws, https -> wss.
        proto, without_proto = self.url.split("://", 1)
        if proto in ["ws", "http"]:
            proto = "ws"
        elif proto in ["wss", "https"]:
            proto = "wss"

        self._generator = None
        self.output_chunk_size = 1920
        self.ws = None
        self.ws_url = f"{proto}://{without_proto}/api/chat"
        self.stream_reader = sphn.OpusStreamReader(output_sample_rate)
        self.stream_writer = sphn.OpusStreamWriter(output_sample_rate)
        # Accumulates decoded server audio until a full chunk is available.
        self.all_output_data = None
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=24000,
        )

    def receive(self, frame: tuple[int, NDArray]) -> None:
        """Encode one incoming PCM frame and send it to the server.

        Lazily opens the websocket connection on the first frame.
        """
        if not self.ws:
            self.ws = websockets.sync.client.connect(self.ws_url)
        _, array = frame
        # int16 PCM -> float32 in [-1, 1). (Renamed local so it no longer
        # shadows the builtin `bytes`.)
        array = array.squeeze().astype(np.float32) / 32768.0
        self.stream_writer.append_pcm(array)
        # Leading 0x01 byte tags the payload as audio in the server protocol.
        payload = b"\x01" + self.stream_writer.read_bytes()
        self.ws.send(payload)

    def generator(
        self,
    ) -> Generator[tuple[int, NDArray] | None | AdditionalOutputs, None, None]:
        """Yield decoded audio chunks and text outputs read from the websocket."""
        for message in cast(websockets.sync.client.ClientConnection, self.ws):
            if len(message) == 0:
                # Empty/keep-alive frame: nothing to emit. The `continue`
                # (bug fix) prevents an IndexError from `message[0]` below.
                yield None
                continue
            kind = message[0]
            if kind == 1:  # audio payload
                payload = message[1:]
                self.stream_reader.append_bytes(payload)
                pcm = self.stream_reader.read_pcm()
                if self.all_output_data is None:
                    self.all_output_data = pcm
                else:
                    self.all_output_data = np.concatenate((self.all_output_data, pcm))
                # Emit fixed-size chunks; keep the remainder buffered.
                while self.all_output_data.shape[-1] >= self.output_chunk_size:
                    yield (
                        self.output_sample_rate,
                        self.all_output_data[: self.output_chunk_size].reshape(1, -1),
                    )
                    self.all_output_data = np.array(
                        self.all_output_data[self.output_chunk_size :]
                    )
            elif kind == 2:  # text payload
                payload = cast(bytes, message[1:])
                yield AdditionalOutputs(payload.decode())

    def emit(self) -> tuple[int, NDArray] | AdditionalOutputs | None:
        """Return the next server output for the WebRTC component, or None if idle."""
        if not self.ws:
            return
        if not self._generator:
            self._generator = self.generator()
        try:
            return next(self._generator)
        except StopIteration:
            self.reset()

    def reset(self) -> None:
        """Drop the active generator and any buffered output audio."""
        self._generator = None
        self.all_output_data = None

    def copy(self) -> StreamHandler:
        """Return a fresh handler instance for a new connection."""
        return MoshiHandler(
            self.url,
            self.expected_layout,  # type: ignore
            self.output_sample_rate,
            self.output_frame_size,
        )

    def shutdown(self) -> None:
        """Close the websocket connection if it was ever opened."""
        if self.ws:
            self.ws.close()
def main():
    """Launch a Gradio UI that talks to a moshi server over WebRTC."""
    arg_parser = argparse.ArgumentParser("client_gradio")
    arg_parser.add_argument("--url", type=str, help="URL to moshi server.")
    cli_args = arg_parser.parse_args()

    # Appends each streamed text fragment to the assistant's chat message.
    def add_text(chat_history, response):
        if not chat_history:
            chat_history.append({"role": "assistant", "content": ""})
        chat_history[-1]["content"] += response
        return chat_history

    with gr.Blocks() as demo:
        gr.HTML(
            """
            <div style='text-align: center'>
            <h1>
            Talk To Moshi (Powered by WebRTC ⚡️)
            </h1>
            <p>
            Each conversation is limited to 90 seconds. Once the time limit is up you can rejoin the conversation.
            </p>
            </div>
            """
        )
        chatbot = gr.Chatbot(type="messages", value=[])
        webrtc = WebRTC(
            label="Conversation",
            modality="audio",
            mode="send-receive",
            rtc_configuration=rtc_configuration,
        )
        webrtc.stream(
            MoshiHandler(cli_args.url),
            inputs=[webrtc, chatbot],
            outputs=[webrtc],
            time_limit=90,
        )
        webrtc.on_additional_outputs(
            add_text,
            inputs=[chatbot],
            outputs=chatbot,
            queue=False,
            show_progress="hidden",
        )

    demo.launch()


if __name__ == "__main__":
    main()

moshi/pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,5 @@ dev = [
3636
"pyright",
3737
"flake8",
3838
"pre-commit",
39+
"gradio-webrtc>=0.0.18"
3940
]

0 commit comments

Comments
 (0)