From c3fa34325c16717ca9b4aa8c30ff29b7cb7fad95 Mon Sep 17 00:00:00 2001 From: laurent Date: Wed, 11 Dec 2024 08:45:27 +0100 Subject: [PATCH 1/2] More internal changes merge. --- rust/moshi-backend/src/stream_both.rs | 84 +++++++++++++++++++++++++-- 1 file changed, 80 insertions(+), 4 deletions(-) diff --git a/rust/moshi-backend/src/stream_both.rs b/rust/moshi-backend/src/stream_both.rs index 5266f9b..11418a4 100644 --- a/rust/moshi-backend/src/stream_both.rs +++ b/rust/moshi-backend/src/stream_both.rs @@ -22,6 +22,7 @@ pub struct Config { pub lm_config: Option, #[serde(default = "default_false")] pub use_cpu_for_mimi: bool, + pub asr_delay_in_tokens: Option, } fn default_false() -> bool { @@ -322,6 +323,62 @@ pub struct StreamingModel { } impl StreamingModel { + fn run_with_state_asr( + &self, + state: &mut moshi::lm_generate_multistream::State, + receiver: std::sync::mpsc::Receiver>, + sender: tokio::sync::mpsc::UnboundedSender, + asr_delay_in_tokens: usize, + ) -> Result<()> { + use candle::IndexOp; + + let app_state = &self.state; + + let mut mimi = app_state.mimi_model.clone(); + let config = state.config().clone(); + + mimi.reset_state(); + tracing::info!("processing loop"); + let mut prev_text_token = config.text_start_token; + let mimi_device = + if self.state.config.use_cpu_for_mimi { &candle::Device::Cpu } else { &self.device }; + mimi_device.synchronize()?; + sender.send(StreamOut::Ready)?; + while let Ok(in_pcm) = receiver.recv() { + if in_pcm.is_empty() { + continue; + } + let pcm_len = in_pcm.len(); + sender.send(StreamOut::InputPcm { pcm_len })?; + let pcms = candle::Tensor::from_vec(in_pcm, (1, 1, pcm_len), mimi_device)?; + let audio_tokens = mimi.encode_step(&pcms.into())?; + let audio_tokens = match audio_tokens.as_option() { + None => continue, + Some(audio_tokens) => audio_tokens, + }; + let (_one, _codebooks, steps) = audio_tokens.dims3()?; + + for step in 0..steps { + let codes = audio_tokens.i((0, .., step))?.to_vec1::()?; + // For the ASR, we don't provide text tokens during the initial steps except the + // initial one. + if state.step_idx() > 0 && state.step_idx() < asr_delay_in_tokens { + prev_text_token = state.step_(None, &codes, None)?; + } else { + sender.send(StreamOut::StepStart { step })?; + let text_token = state.step(prev_text_token, &codes, None)?; + sender.send(StreamOut::StepPostSampling { step })?; + if let Some(text) = app_state.text(prev_text_token, text_token, &config) { + sender.send(StreamOut::Text { text })?; + } + prev_text_token = text_token; + } + } + } + tracing::info!("finished the processing loop"); + Ok(()) + } + fn run_with_state( &self, state: &mut moshi::lm_generate_multistream::State, @@ -374,7 +431,6 @@ impl StreamingModel { sender.send(StreamOut::Pcm { pcm })?; } } - if let Some(text) = app_state.text(prev_text_token, text_token, &config) { sender.send(StreamOut::Text { text })?; } @@ -550,6 +606,8 @@ impl StreamingModel { // We want to log the output even if the run function returns an error. let run_result = if self.state.config.use_cpu_for_mimi { self.run_with_state_mt(&mut state, receiver, sender) + } else if let Some(asr_delay_in_tokens) = self.state.config.asr_delay_in_tokens { + self.run_with_state_asr(&mut state, receiver, sender, asr_delay_in_tokens) } else { self.run_with_state(&mut state, receiver, sender) }; @@ -577,8 +635,22 @@ impl StreamingModel { .unwrap_or_else(|_| String::new()) }; let audio_tokens = state.audio_tokens(false); - let audio_tokens = audio_tokens.iter().map(|v| v.as_slice()).collect::>(); - let text_tokens = candle::Tensor::new(text_tokens, &candle::Device::Cpu)?; + let audio_tokens = audio_tokens + .iter() + .map(|v| { + v.iter() + .map(|v| { + if *v == moshi::lm_generate_multistream::UNGENERATED { + -1 + } else { + *v as i64 + } + }) + .collect::>() + }) + .collect::>(); + let text_tokens = candle::Tensor::new(text_tokens, &candle::Device::Cpu)? + .to_dtype(candle::DType::I64)?; let audio_tokens = candle::Tensor::new(audio_tokens, &candle::Device::Cpu)?; let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?; let (secs, us) = (since_epoch.as_secs(), since_epoch.subsec_micros()); @@ -718,7 +790,11 @@ pub async fn handle_socket( let (in_pcm_tx, in_pcm_rx) = std::sync::mpsc::channel(); let (stream_out_tx, stream_out_rx) = tokio::sync::mpsc::unbounded_channel(); let (loop1, loop2) = spawn_recv_loops(receiver, in_pcm_tx)?; - std::thread::spawn(move || sm.run(in_pcm_rx, stream_out_tx, addr)); + std::thread::spawn(move || { + if let Err(err) = sm.run(in_pcm_rx, stream_out_tx, addr) { + tracing::error!("{err}") + } + }); let sender_loop = tokio::spawn(async move { match sender_loop(stream_out_rx, sender).await { Ok(()) => tracing::info!("sender closed"), From 0b9f272b454abd835445bcf07c5e504bd9624f4c Mon Sep 17 00:00:00 2001 From: laurent Date: Wed, 11 Dec 2024 09:17:13 +0100 Subject: [PATCH 2/2] Merge more changes. --- rust/moshi-backend/Cargo.toml | 8 ++++++++ rust/moshi-backend/src/benchmark.rs | 4 ++-- rust/moshi-backend/src/main.rs | 3 +++ rust/moshi-backend/src/standalone.rs | 2 +- rust/moshi-backend/src/stream_both.rs | 1 + 5 files changed, 15 insertions(+), 3 deletions(-) diff --git a/rust/moshi-backend/Cargo.toml b/rust/moshi-backend/Cargo.toml index e6df16c..f669159 100644 --- a/rust/moshi-backend/Cargo.toml +++ b/rust/moshi-backend/Cargo.toml @@ -56,3 +56,11 @@ vergen = { version = "8.3.1", features = ["build", "cargo", "git", "gitcl", "rus default = [] cuda = ["moshi/cuda", "candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] metal = ["moshi/metal", "candle/metal", "candle-nn/metal", "candle-transformers/metal"] + +[profile.release] +debug = true + +[profile.release-no-debug] +inherits = "release" +debug = false + diff --git a/rust/moshi-backend/src/benchmark.rs b/rust/moshi-backend/src/benchmark.rs index 701bbe0..f6bd202 100644 --- a/rust/moshi-backend/src/benchmark.rs +++ b/rust/moshi-backend/src/benchmark.rs @@ -131,8 +131,8 @@ pub async fn run(args: &crate::BenchmarkArgs, config: &Config) -> Result<()> { tokio::time::sleep_until(target_time).await; in_pcm_tx.send(zeros.to_vec())?; } - let _ = task.await; - let _ = w.await; + task.await?; + w.await??; } } Ok(()) diff --git a/rust/moshi-backend/src/main.rs b/rust/moshi-backend/src/main.rs index 6f88890..42a8879 100644 --- a/rust/moshi-backend/src/main.rs +++ b/rust/moshi-backend/src/main.rs @@ -51,6 +51,9 @@ pub struct BenchmarkArgs { #[clap(long)] chrome_tracing: bool, + #[clap(long)] + asr: bool, + #[clap(long)] mimi_only: bool, } diff --git a/rust/moshi-backend/src/standalone.rs b/rust/moshi-backend/src/standalone.rs index 56afbec..b953366 100644 --- a/rust/moshi-backend/src/standalone.rs +++ b/rust/moshi-backend/src/standalone.rs @@ -82,7 +82,7 @@ impl stream_both::AppStateInner { let codes = mimi_model.encode_step(&fake_pcm.into())?; let ys = mimi_model.decode_step(&codes)?; if ys.as_option().is_none() { - anyhow::bail!("Expected mimi to output some stuff, but nothing came out."); + anyhow::bail!("Expected Mimi to output some stuff, but nothing came out."); } device.synchronize()?; tracing::info!("model is ready to roll!"); diff --git a/rust/moshi-backend/src/stream_both.rs b/rust/moshi-backend/src/stream_both.rs index 11418a4..c983272 100644 --- a/rust/moshi-backend/src/stream_both.rs +++ b/rust/moshi-backend/src/stream_both.rs @@ -13,6 +13,7 @@ use std::sync::Arc; #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct Config { pub instance_name: String, + #[serde(default)] pub hf_repo: String, pub lm_model_file: String, pub log_dir: String,