Skip to content

Commit

Permalink
More internal changes merge.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Dec 11, 2024
1 parent cf8b12e commit c3fa343
Showing 1 changed file with 80 additions and 4 deletions.
84 changes: 80 additions & 4 deletions rust/moshi-backend/src/stream_both.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub struct Config {
pub lm_config: Option<moshi::lm_generate_multistream::Config>,
#[serde(default = "default_false")]
pub use_cpu_for_mimi: bool,
pub asr_delay_in_tokens: Option<usize>,
}

fn default_false() -> bool {
Expand Down Expand Up @@ -322,6 +323,62 @@ pub struct StreamingModel {
}

impl StreamingModel {
fn run_with_state_asr(
&self,
state: &mut moshi::lm_generate_multistream::State,
receiver: std::sync::mpsc::Receiver<Vec<f32>>,
sender: tokio::sync::mpsc::UnboundedSender<StreamOut>,
asr_delay_in_tokens: usize,
) -> Result<()> {
use candle::IndexOp;

let app_state = &self.state;

let mut mimi = app_state.mimi_model.clone();
let config = state.config().clone();

mimi.reset_state();
tracing::info!("processing loop");
let mut prev_text_token = config.text_start_token;
let mimi_device =
if self.state.config.use_cpu_for_mimi { &candle::Device::Cpu } else { &self.device };
mimi_device.synchronize()?;
sender.send(StreamOut::Ready)?;
while let Ok(in_pcm) = receiver.recv() {
if in_pcm.is_empty() {
continue;
}
let pcm_len = in_pcm.len();
sender.send(StreamOut::InputPcm { pcm_len })?;
let pcms = candle::Tensor::from_vec(in_pcm, (1, 1, pcm_len), mimi_device)?;
let audio_tokens = mimi.encode_step(&pcms.into())?;
let audio_tokens = match audio_tokens.as_option() {
None => continue,
Some(audio_tokens) => audio_tokens,
};
let (_one, _codebooks, steps) = audio_tokens.dims3()?;

for step in 0..steps {
let codes = audio_tokens.i((0, .., step))?.to_vec1::<u32>()?;
// For the ASR, we don't provide text tokens during the initial steps except the
// initial one.
if state.step_idx() > 0 && state.step_idx() < asr_delay_in_tokens {
prev_text_token = state.step_(None, &codes, None)?;
} else {
sender.send(StreamOut::StepStart { step })?;
let text_token = state.step(prev_text_token, &codes, None)?;
sender.send(StreamOut::StepPostSampling { step })?;
if let Some(text) = app_state.text(prev_text_token, text_token, &config) {
sender.send(StreamOut::Text { text })?;
}
prev_text_token = text_token;
}
}
}
tracing::info!("finished the processing loop");
Ok(())
}

fn run_with_state(
&self,
state: &mut moshi::lm_generate_multistream::State,
Expand Down Expand Up @@ -374,7 +431,6 @@ impl StreamingModel {
sender.send(StreamOut::Pcm { pcm })?;
}
}

if let Some(text) = app_state.text(prev_text_token, text_token, &config) {
sender.send(StreamOut::Text { text })?;
}
Expand Down Expand Up @@ -550,6 +606,8 @@ impl StreamingModel {
// We want to log the output even if the run function returns an error.
let run_result = if self.state.config.use_cpu_for_mimi {
self.run_with_state_mt(&mut state, receiver, sender)
} else if let Some(asr_delay_in_tokens) = self.state.config.asr_delay_in_tokens {
self.run_with_state_asr(&mut state, receiver, sender, asr_delay_in_tokens)
} else {
self.run_with_state(&mut state, receiver, sender)
};
Expand Down Expand Up @@ -577,8 +635,22 @@ impl StreamingModel {
.unwrap_or_else(|_| String::new())
};
let audio_tokens = state.audio_tokens(false);
let audio_tokens = audio_tokens.iter().map(|v| v.as_slice()).collect::<Vec<_>>();
let text_tokens = candle::Tensor::new(text_tokens, &candle::Device::Cpu)?;
let audio_tokens = audio_tokens
.iter()
.map(|v| {
v.iter()
.map(|v| {
if *v == moshi::lm_generate_multistream::UNGENERATED {
-1
} else {
*v as i64
}
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
let text_tokens = candle::Tensor::new(text_tokens, &candle::Device::Cpu)?
.to_dtype(candle::DType::I64)?;
let audio_tokens = candle::Tensor::new(audio_tokens, &candle::Device::Cpu)?;
let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?;
let (secs, us) = (since_epoch.as_secs(), since_epoch.subsec_micros());
Expand Down Expand Up @@ -718,7 +790,11 @@ pub async fn handle_socket(
let (in_pcm_tx, in_pcm_rx) = std::sync::mpsc::channel();
let (stream_out_tx, stream_out_rx) = tokio::sync::mpsc::unbounded_channel();
let (loop1, loop2) = spawn_recv_loops(receiver, in_pcm_tx)?;
std::thread::spawn(move || sm.run(in_pcm_rx, stream_out_tx, addr));
std::thread::spawn(move || {
if let Err(err) = sm.run(in_pcm_rx, stream_out_tx, addr) {
tracing::error!("{err}")
}
});
let sender_loop = tokio::spawn(async move {
match sender_loop(stream_out_rx, sender).await {
Ok(()) => tracing::info!("sender closed"),
Expand Down

0 comments on commit c3fa343

Please sign in to comment.