Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More internal changes merge. #171

Merged
merged 2 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions rust/moshi-backend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,11 @@ vergen = { version = "8.3.1", features = ["build", "cargo", "git", "gitcl", "rus
default = []
cuda = ["moshi/cuda", "candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
metal = ["moshi/metal", "candle/metal", "candle-nn/metal", "candle-transformers/metal"]

# Keep debug symbols in release builds so production stack traces and
# profiler output stay symbolized (at the cost of larger binaries).
[profile.release]
debug = true

# Opt-in profile for stripped binaries: identical to `release` except the
# debug info is dropped. Build with `cargo build --profile release-no-debug`.
[profile.release-no-debug]
inherits = "release"
debug = false

4 changes: 2 additions & 2 deletions rust/moshi-backend/src/benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ pub async fn run(args: &crate::BenchmarkArgs, config: &Config) -> Result<()> {
tokio::time::sleep_until(target_time).await;
in_pcm_tx.send(zeros.to_vec())?;
}
let _ = task.await;
let _ = w.await;
task.await?;
w.await??;
}
}
Ok(())
Expand Down
3 changes: 3 additions & 0 deletions rust/moshi-backend/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ pub struct BenchmarkArgs {
#[clap(long)]
chrome_tracing: bool,

#[clap(long)]
asr: bool,

#[clap(long)]
mimi_only: bool,
}
Expand Down
2 changes: 1 addition & 1 deletion rust/moshi-backend/src/standalone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl stream_both::AppStateInner {
let codes = mimi_model.encode_step(&fake_pcm.into())?;
let ys = mimi_model.decode_step(&codes)?;
if ys.as_option().is_none() {
anyhow::bail!("Expected mimi to output some stuff, but nothing came out.");
anyhow::bail!("Expected Mimi to output some stuff, but nothing came out.");
}
device.synchronize()?;
tracing::info!("model is ready to roll!");
Expand Down
85 changes: 81 additions & 4 deletions rust/moshi-backend/src/stream_both.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use std::sync::Arc;
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Config {
pub instance_name: String,
#[serde(default)]
pub hf_repo: String,
pub lm_model_file: String,
pub log_dir: String,
Expand All @@ -22,6 +23,7 @@ pub struct Config {
pub lm_config: Option<moshi::lm_generate_multistream::Config>,
#[serde(default = "default_false")]
pub use_cpu_for_mimi: bool,
pub asr_delay_in_tokens: Option<usize>,
}

fn default_false() -> bool {
Expand Down Expand Up @@ -322,6 +324,62 @@ pub struct StreamingModel {
}

impl StreamingModel {
fn run_with_state_asr(
&self,
state: &mut moshi::lm_generate_multistream::State,
receiver: std::sync::mpsc::Receiver<Vec<f32>>,
sender: tokio::sync::mpsc::UnboundedSender<StreamOut>,
asr_delay_in_tokens: usize,
) -> Result<()> {
use candle::IndexOp;

let app_state = &self.state;

let mut mimi = app_state.mimi_model.clone();
let config = state.config().clone();

mimi.reset_state();
tracing::info!("processing loop");
let mut prev_text_token = config.text_start_token;
let mimi_device =
if self.state.config.use_cpu_for_mimi { &candle::Device::Cpu } else { &self.device };
mimi_device.synchronize()?;
sender.send(StreamOut::Ready)?;
while let Ok(in_pcm) = receiver.recv() {
if in_pcm.is_empty() {
continue;
}
let pcm_len = in_pcm.len();
sender.send(StreamOut::InputPcm { pcm_len })?;
let pcms = candle::Tensor::from_vec(in_pcm, (1, 1, pcm_len), mimi_device)?;
let audio_tokens = mimi.encode_step(&pcms.into())?;
let audio_tokens = match audio_tokens.as_option() {
None => continue,
Some(audio_tokens) => audio_tokens,
};
let (_one, _codebooks, steps) = audio_tokens.dims3()?;

for step in 0..steps {
let codes = audio_tokens.i((0, .., step))?.to_vec1::<u32>()?;
// For the ASR, we don't provide text tokens during the initial steps except the
// initial one.
if state.step_idx() > 0 && state.step_idx() < asr_delay_in_tokens {
prev_text_token = state.step_(None, &codes, None)?;
} else {
sender.send(StreamOut::StepStart { step })?;
let text_token = state.step(prev_text_token, &codes, None)?;
sender.send(StreamOut::StepPostSampling { step })?;
if let Some(text) = app_state.text(prev_text_token, text_token, &config) {
sender.send(StreamOut::Text { text })?;
}
prev_text_token = text_token;
}
}
}
tracing::info!("finished the processing loop");
Ok(())
}

fn run_with_state(
&self,
state: &mut moshi::lm_generate_multistream::State,
Expand Down Expand Up @@ -374,7 +432,6 @@ impl StreamingModel {
sender.send(StreamOut::Pcm { pcm })?;
}
}

if let Some(text) = app_state.text(prev_text_token, text_token, &config) {
sender.send(StreamOut::Text { text })?;
}
Expand Down Expand Up @@ -550,6 +607,8 @@ impl StreamingModel {
// We want to log the output even if the run function returns an error.
let run_result = if self.state.config.use_cpu_for_mimi {
self.run_with_state_mt(&mut state, receiver, sender)
} else if let Some(asr_delay_in_tokens) = self.state.config.asr_delay_in_tokens {
self.run_with_state_asr(&mut state, receiver, sender, asr_delay_in_tokens)
} else {
self.run_with_state(&mut state, receiver, sender)
};
Expand Down Expand Up @@ -577,8 +636,22 @@ impl StreamingModel {
.unwrap_or_else(|_| String::new())
};
let audio_tokens = state.audio_tokens(false);
let audio_tokens = audio_tokens.iter().map(|v| v.as_slice()).collect::<Vec<_>>();
let text_tokens = candle::Tensor::new(text_tokens, &candle::Device::Cpu)?;
let audio_tokens = audio_tokens
.iter()
.map(|v| {
v.iter()
.map(|v| {
if *v == moshi::lm_generate_multistream::UNGENERATED {
-1
} else {
*v as i64
}
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
let text_tokens = candle::Tensor::new(text_tokens, &candle::Device::Cpu)?
.to_dtype(candle::DType::I64)?;
let audio_tokens = candle::Tensor::new(audio_tokens, &candle::Device::Cpu)?;
let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?;
let (secs, us) = (since_epoch.as_secs(), since_epoch.subsec_micros());
Expand Down Expand Up @@ -718,7 +791,11 @@ pub async fn handle_socket(
let (in_pcm_tx, in_pcm_rx) = std::sync::mpsc::channel();
let (stream_out_tx, stream_out_rx) = tokio::sync::mpsc::unbounded_channel();
let (loop1, loop2) = spawn_recv_loops(receiver, in_pcm_tx)?;
std::thread::spawn(move || sm.run(in_pcm_rx, stream_out_tx, addr));
std::thread::spawn(move || {
if let Err(err) = sm.run(in_pcm_rx, stream_out_tx, addr) {
tracing::error!("{err}")
}
});
let sender_loop = tokio::spawn(async move {
match sender_loop(stream_out_rx, sender).await {
Ok(()) => tracing::info!("sender closed"),
Expand Down
Loading