diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 751f058..b0e1a3c 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -8,7 +8,7 @@ members = [ resolver = "2" [workspace.package] -version = "0.2.4" +version = "0.3.0" edition = "2021" license = "MIT/Apache-2.0" description = "moshi, a real-time voice AI" @@ -18,10 +18,10 @@ categories = ["science"] [workspace.dependencies] -candle = { version = "0.7.2", package = "candle-core" } -candle-nn = "0.7.2" -candle-transformers = "0.7.2" -candle-flash-attn = "0.7.2" +candle = { version = "0.8.1", package = "candle-core" } +candle-nn = "0.8.1" +candle-transformers = "0.8.1" +candle-flash-attn = "0.8.1" [profile.release] debug = true diff --git a/rust/mimi-pyo3/Cargo.toml b/rust/mimi-pyo3/Cargo.toml index 824cbdd..ac4a5ef 100644 --- a/rust/mimi-pyo3/Cargo.toml +++ b/rust/mimi-pyo3/Cargo.toml @@ -16,4 +16,4 @@ crate-type = ["cdylib"] anyhow = "1" numpy = "0.21.0" pyo3 = "0.21.0" -moshi = { path = "../moshi-core", version = "0.2.4" } +moshi = { path = "../moshi-core", version = "0.3.0" } diff --git a/rust/mimi-pyo3/src/lib.rs b/rust/mimi-pyo3/src/lib.rs index 64e9f3f..9e3a2a5 100644 --- a/rust/mimi-pyo3/src/lib.rs +++ b/rust/mimi-pyo3/src/lib.rs @@ -5,7 +5,7 @@ use pyo3::prelude::*; use ::moshi as mm; -use mm::{candle, candle_nn, conv, encodec, seanet, transformer}; +use mm::{candle, candle_nn, conv, mimi, seanet, transformer}; trait PyRes { #[allow(unused)] @@ -39,7 +39,7 @@ macro_rules! py_bail { }; } -fn encodec_cfg(max_seq_len: Option) -> encodec::Config { +fn mimi_cfg(max_seq_len: Option) -> mimi::Config { let seanet_cfg = seanet::Config { dimension: 512, channels: 1, @@ -84,12 +84,12 @@ fn encodec_cfg(max_seq_len: Option) -> encodec::Config { cross_attention: None, max_seq_len: max_seq_len.unwrap_or(8192), // the transformer works at 25hz so this is ~5 mins. }; - encodec::Config { + mimi::Config { channels: 1, sample_rate: 24_000., frame_rate: 12.5, renormalize: true, - resample_method: encodec::ResampleMethod::Conv, + resample_method: mimi::ResampleMethod::Conv, seanet: seanet_cfg, transformer: transformer_cfg, quantizer_n_q: 8, @@ -100,7 +100,7 @@ fn encodec_cfg(max_seq_len: Option) -> encodec::Config { #[pyclass] struct Tokenizer { - encodec: encodec::Encodec, + mimi: mimi::Mimi, device: candle::Device, dtype: candle::DType, } @@ -119,9 +119,9 @@ impl Tokenizer { }; let vb = unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[path], dtype, &device).w()? }; - let cfg = encodec_cfg(max_seq_len); - let encodec = encodec::Encodec::new(cfg, vb).w()?; - Ok(Self { encodec, device, dtype }) + let cfg = mimi_cfg(max_seq_len); + let mimi = mimi::Mimi::new(cfg, vb).w()?; + Ok(Self { mimi, device, dtype }) } fn encode(&mut self, pcm_data: numpy::PyReadonlyArray3) -> PyResult { @@ -136,7 +136,7 @@ impl Tokenizer { .allow_threads(|| { let pcm_data = candle::Tensor::from_slice(pcm_data, pcm_shape, &self.device)? .to_dtype(self.dtype)?; - let codes = self.encodec.encode(&pcm_data)?; + let codes = self.mimi.encode(&pcm_data)?; codes.to_vec3::() }) .w()?; @@ -156,7 +156,7 @@ impl Tokenizer { .allow_threads(|| { let pcm_data = candle::Tensor::from_slice(pcm_data, pcm_shape, &self.device)? .to_dtype(self.dtype)?; - let codes = self.encodec.encode_step(&pcm_data.into())?; + let codes = self.mimi.encode_step(&pcm_data.into())?; match codes.as_option() { Some(codes) => Ok::<_, candle::Error>(Some(codes.to_vec3::()?)), None => Ok(None), @@ -182,7 +182,7 @@ impl Tokenizer { let pcm = py .allow_threads(|| { let codes = candle::Tensor::from_slice(codes, codes_shape, &self.device)?; - let pcm = self.encodec.decode(&codes)?.to_dtype(candle::DType::F32)?; + let pcm = self.mimi.decode(&codes)?.to_dtype(candle::DType::F32)?; pcm.to_vec3::() }) .w()?; @@ -204,7 +204,7 @@ impl Tokenizer { let pcm = py .allow_threads(|| { let codes = candle::Tensor::from_slice(codes, codes_shape, &self.device)?; - let pcm = self.encodec.decode_step(&codes.into())?; + let pcm = self.mimi.decode_step(&codes.into())?; match pcm.as_option() { Some(pcm) => { let pcm = pcm.to_dtype(candle::DType::F32)?; @@ -224,7 +224,7 @@ impl Tokenizer { } fn reset(&mut self) { - self.encodec.reset_state() + self.mimi.reset_state() } } @@ -252,9 +252,9 @@ impl StreamTokenizer { }; let vb = unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[path], dtype, &device).w()? }; - let cfg = encodec_cfg(max_seq_len); - let mut e_encodec = encodec::Encodec::new(cfg, vb).w()?; - let mut d_encodec = e_encodec.clone(); + let cfg = mimi_cfg(max_seq_len); + let mut e_mimi = mimi::Mimi::new(cfg, vb).w()?; + let mut d_mimi = e_mimi.clone(); let (encoder_tx, e_rx) = std::sync::mpsc::channel::>(); let (decoder_tx, d_rx) = std::sync::mpsc::channel::>>(); let (d_tx, decoder_rx) = std::sync::mpsc::channel::>(); @@ -267,7 +267,7 @@ impl StreamTokenizer { let pcm_data = candle::Tensor::from_vec(pcm_data, (1, 1, l), &candle::Device::Cpu)? .to_dtype(dtype)?; - let codes = e_encodec.encode_step(&pcm_data.into())?; + let codes = e_mimi.encode_step(&pcm_data.into())?; if let Some(codes) = codes.as_option() { let mut codes = codes.to_vec3::()?; e_tx.send(codes.remove(0))?; @@ -282,7 +282,7 @@ impl StreamTokenizer { while let Ok(codes) = d_rx.recv() { if let Err(err) = (|| { let codes = candle::Tensor::new(codes, &candle::Device::Cpu)?.unsqueeze(2)?; - let pcm_data = d_encodec.decode_step(&codes.into())?; + let pcm_data = d_mimi.decode_step(&codes.into())?; if let Some(pcm_data) = pcm_data.as_option() { let mut pcm_data = pcm_data.to_vec3::()?; d_tx.send(pcm_data.remove(0).remove(0))?; diff --git a/rust/moshi-backend/Cargo.toml b/rust/moshi-backend/Cargo.toml index b57f003..e6df16c 100644 --- a/rust/moshi-backend/Cargo.toml +++ b/rust/moshi-backend/Cargo.toml @@ -26,7 +26,7 @@ rcgen = "0.13.1" http = "1.1.0" lazy_static = "1.5.0" log = "0.4.20" -moshi = { path = "../moshi-core", version = "0.2.4" } +moshi = { path = "../moshi-core", version = "0.3.0" } ogg = { version = "0.9.1", features = ["async"] } opus = "0.3.0" rand = { version = "0.8.5", features = ["getrandom"] } diff --git a/rust/moshi-backend/config-q8.json b/rust/moshi-backend/config-q8.json index 491c782..6668a1f 100644 --- a/rust/moshi-backend/config-q8.json +++ b/rust/moshi-backend/config-q8.json @@ -4,8 +4,8 @@ "lm_model_file": "$HOME/tmp/moshiko_rs_301e30bf@120/model.q8.gguf", "text_tokenizer_file": "$HOME/tmp/tokenizer_spm_32k_3.model", "log_dir": "$HOME/tmp/moshi-logs", - "encodec_model_file": "$HOME/tmp/tokenizer-e351c8d8-checkpoint125.safetensors", - "encodec_num_codebooks": 8, + "mimi_model_file": "$HOME/tmp/tokenizer-e351c8d8-checkpoint125.safetensors", + "mimi_num_codebooks": 8, "static_dir": "../client/dist", "addr": "0.0.0.0", "port": 8998, diff --git a/rust/moshi-backend/config.json b/rust/moshi-backend/config.json index d23dd85..57f28e2 100644 --- a/rust/moshi-backend/config.json +++ b/rust/moshi-backend/config.json @@ -4,8 +4,8 @@ "lm_model_file": "$HOME/tmp/moshiko_rs_301e30bf@120/model.safetensors", "text_tokenizer_file": "$HOME/tmp/tokenizer_spm_32k_3.model", "log_dir": "$HOME/tmp/moshi-logs", - "encodec_model_file": "$HOME/tmp/tokenizer-e351c8d8-checkpoint125.safetensors", - "encodec_num_codebooks": 8, + "mimi_model_file": "$HOME/tmp/tokenizer-e351c8d8-checkpoint125.safetensors", + "mimi_num_codebooks": 8, "static_dir": "../client/dist", "addr": "0.0.0.0", "port": 8998, diff --git a/rust/moshi-backend/src/benchmark.rs b/rust/moshi-backend/src/benchmark.rs index c44fb62..701bbe0 100644 --- a/rust/moshi-backend/src/benchmark.rs +++ b/rust/moshi-backend/src/benchmark.rs @@ -79,22 +79,21 @@ pub async fn run(args: &crate::BenchmarkArgs, config: &Config) -> Result<()> { }; if args.mimi_only { let device = crate::standalone::device(args.cpu)?; - let encodec_device = - if config.use_cpu_for_encodec { &candle::Device::Cpu } else { &device }; - let mut encodec_model = moshi::encodec::load( - &config.encodec_model_file, - Some(config.encodec_num_codebooks), - encodec_device, + let mimi_device = if config.use_cpu_for_mimi { &candle::Device::Cpu } else { &device }; + let mut mimi_model = moshi::mimi::load( + &config.mimi_model_file, + Some(config.mimi_num_codebooks), + mimi_device, )?; - let config = encodec_model.config(); + let config = mimi_model.config(); let frame_length = (config.sample_rate / config.frame_rate).ceil() as usize; for _step in 0..args.steps { let fake_pcm = - candle::Tensor::zeros((1, 1, frame_length), candle::DType::F32, encodec_device)?; - let codes = encodec_model.encode_step(&fake_pcm.into())?; - let ys = encodec_model.decode_step(&codes)?; + candle::Tensor::zeros((1, 1, frame_length), candle::DType::F32, mimi_device)?; + let codes = mimi_model.encode_step(&fake_pcm.into())?; + let ys = mimi_model.decode_step(&codes)?; if ys.as_option().is_none() { - anyhow::bail!("Expected Encodec to output some stuff, but nothing came out."); + anyhow::bail!("Expected mimi to output some stuff, but nothing came out."); } device.synchronize()?; } diff --git a/rust/moshi-backend/src/standalone.rs b/rust/moshi-backend/src/standalone.rs index 48614fd..56afbec 100644 --- a/rust/moshi-backend/src/standalone.rs +++ b/rust/moshi-backend/src/standalone.rs @@ -29,8 +29,8 @@ impl Config { config.stream.log_dir = crate::utils::replace_env_vars(&config.stream.log_dir); config.stream.text_tokenizer_file = crate::utils::replace_env_vars(&config.stream.text_tokenizer_file); - config.stream.encodec_model_file = - crate::utils::replace_env_vars(&config.stream.encodec_model_file); + config.stream.mimi_model_file = + crate::utils::replace_env_vars(&config.stream.mimi_model_file); config.stream.lm_model_file = crate::utils::replace_env_vars(&config.stream.lm_model_file); Ok(config) } @@ -59,12 +59,11 @@ impl stream_both::AppStateInner { let device = device(args.cpu)?; let dtype = if device.is_cuda() { candle::DType::BF16 } else { candle::DType::F32 }; let lm_model = moshi::lm::load_streaming(&config.lm_model_file, dtype, &device)?; - let encodec_device = - if config.use_cpu_for_encodec { &candle::Device::Cpu } else { &device }; - let encodec_model = moshi::encodec::load( - &config.encodec_model_file, - Some(config.encodec_num_codebooks), - encodec_device, + let mimi_device = if config.use_cpu_for_mimi { &candle::Device::Cpu } else { &device }; + let mimi_model = moshi::mimi::load( + &config.mimi_model_file, + Some(config.mimi_num_codebooks), + mimi_device, )?; let text_tokenizer = sentencepiece::SentencePieceProcessor::open(&config.text_tokenizer_file)?; @@ -72,23 +71,23 @@ impl stream_both::AppStateInner { { tracing::info!(?dtype, ?device, "warming up the model"); let mut lm_model = lm_model.clone(); - let (_v, ys) = lm_model.forward(None, vec![None; config.encodec_num_codebooks])?; + let (_v, ys) = lm_model.forward(None, vec![None; config.mimi_num_codebooks])?; let mut lp = candle_transformers::generation::LogitsProcessor::new(123, None, None); let _ = lm_model.depformer_sample(0, &ys, None, &mut lp)?; - let mut encodec_model = encodec_model.clone(); - let config = encodec_model.config(); + let mut mimi_model = mimi_model.clone(); + let config = mimi_model.config(); let frame_length = (config.sample_rate / config.frame_rate).ceil() as usize; let fake_pcm = - candle::Tensor::zeros((1, 1, frame_length), candle::DType::F32, encodec_device)?; - let codes = encodec_model.encode_step(&fake_pcm.into())?; - let ys = encodec_model.decode_step(&codes)?; + candle::Tensor::zeros((1, 1, frame_length), candle::DType::F32, mimi_device)?; + let codes = mimi_model.encode_step(&fake_pcm.into())?; + let ys = mimi_model.decode_step(&codes)?; if ys.as_option().is_none() { - anyhow::bail!("Expected Encodec to output some stuff, but nothing came out."); + anyhow::bail!("Expected mimi to output some stuff, but nothing came out."); } device.synchronize()?; tracing::info!("model is ready to roll!"); } - Ok(Self { lm_model, encodec_model, device, config: config.clone(), text_tokenizer }) + Ok(Self { lm_model, mimi_model, device, config: config.clone(), text_tokenizer }) } } @@ -121,7 +120,7 @@ pub async fn download_from_hub(config: &mut stream_both::Config) -> Result<()> { .ok_or_else(|| anyhow::anyhow!("'{path}' has no file name")) }; for file_path in - [&mut config.lm_model_file, &mut config.encodec_model_file, &mut config.text_tokenizer_file] + [&mut config.lm_model_file, &mut config.mimi_model_file, &mut config.text_tokenizer_file] .iter_mut() { let filename = extract_filename(file_path) diff --git a/rust/moshi-backend/src/stream_both.rs b/rust/moshi-backend/src/stream_both.rs index de18ff4..9dabcc2 100644 --- a/rust/moshi-backend/src/stream_both.rs +++ b/rust/moshi-backend/src/stream_both.rs @@ -17,11 +17,11 @@ pub struct Config { pub lm_model_file: String, pub log_dir: String, pub text_tokenizer_file: String, - pub encodec_model_file: String, - pub encodec_num_codebooks: usize, + pub mimi_model_file: String, + pub mimi_num_codebooks: usize, pub lm_config: Option, #[serde(default = "default_false")] - pub use_cpu_for_encodec: bool, + pub use_cpu_for_mimi: bool, } fn default_false() -> bool { @@ -34,14 +34,14 @@ impl Config { let mut config: Self = serde_json::from_str(&config)?; config.log_dir = crate::utils::replace_env_vars(&config.log_dir); config.text_tokenizer_file = crate::utils::replace_env_vars(&config.text_tokenizer_file); - config.encodec_model_file = crate::utils::replace_env_vars(&config.encodec_model_file); + config.mimi_model_file = crate::utils::replace_env_vars(&config.mimi_model_file); config.lm_model_file = crate::utils::replace_env_vars(&config.lm_model_file); Ok(config) } /// Check if all modelling files are available on machine. pub fn requires_model_download(&self) -> bool { - [&self.lm_model_file, &self.encodec_model_file, &self.text_tokenizer_file] + [&self.lm_model_file, &self.mimi_model_file, &self.text_tokenizer_file] .iter() .any(|file| !std::path::Path::new(file).exists()) } @@ -50,7 +50,7 @@ impl Config { pub type AppState = Arc; pub struct AppStateInner { pub lm_model: moshi::lm::LmModel, - pub encodec_model: moshi::encodec::Encodec, + pub mimi_model: moshi::mimi::Mimi, pub text_tokenizer: sentencepiece::SentencePieceProcessor, pub device: candle::Device, pub config: Config, @@ -126,7 +126,7 @@ struct SessionSummary<'a> { transcript: String, addr: Option, lm_model_file: &'a str, - encodec_model_file: &'a str, + mimi_model_file: &'a str, #[serde(flatten)] lm_config: &'a Option, } @@ -162,7 +162,7 @@ pub struct MetaData { repetition_penalty_context: usize, repetition_penalty: f32, lm_model_file: String, - encodec_model_file: String, + mimi_model_file: String, build_info: crate::utils::BuildInfo, instance_name: String, } @@ -332,16 +332,16 @@ impl StreamingModel { let app_state = &self.state; - let mut encodec = app_state.encodec_model.clone(); + let mut mimi = app_state.mimi_model.clone(); let config = state.config().clone(); - encodec.reset_state(); + mimi.reset_state(); tracing::info!("processing loop"); let mut prev_text_token = config.text_start_token; let mut tensor_tokens = vec![]; - let encodec_device = - if self.state.config.use_cpu_for_encodec { &candle::Device::Cpu } else { &self.device }; - encodec_device.synchronize()?; + let mimi_device = + if self.state.config.use_cpu_for_mimi { &candle::Device::Cpu } else { &self.device }; + mimi_device.synchronize()?; sender.send(StreamOut::Ready)?; while let Ok(in_pcm) = receiver.recv() { if in_pcm.is_empty() { @@ -349,8 +349,8 @@ impl StreamingModel { } let pcm_len = in_pcm.len(); sender.send(StreamOut::InputPcm { pcm_len })?; - let pcms = candle::Tensor::from_vec(in_pcm, (1, 1, pcm_len), encodec_device)?; - let audio_tokens = encodec.encode_step(&pcms.into())?; + let pcms = candle::Tensor::from_vec(in_pcm, (1, 1, pcm_len), mimi_device)?; + let audio_tokens = mimi.encode_step(&pcms.into())?; let audio_tokens = match audio_tokens.as_option() { None => continue, Some(audio_tokens) => audio_tokens, @@ -364,11 +364,11 @@ impl StreamingModel { sender.send(StreamOut::StepPostSampling { step })?; if let Some(audio_tokens) = state.last_audio_tokens() { let audio_tokens = { - let cb = app_state.config.encodec_num_codebooks; - candle::Tensor::from_slice(&audio_tokens[..cb], (1, cb, 1), encodec_device)? + let cb = app_state.config.mimi_num_codebooks; + candle::Tensor::from_slice(&audio_tokens[..cb], (1, cb, 1), mimi_device)? }; tensor_tokens.push(audio_tokens.clone()); - let pcm = encodec.decode_step(&audio_tokens.into())?; + let pcm = mimi.decode_step(&audio_tokens.into())?; if let Some(pcm) = pcm.as_option() { let pcm = pcm.i((0, 0))?.to_vec1::()?; sender.send(StreamOut::Pcm { pcm })?; @@ -395,10 +395,10 @@ impl StreamingModel { let app_state = &self.state; - let mut encodec = app_state.encodec_model.clone(); + let mut mimi = app_state.mimi_model.clone(); let config = state.config().clone(); - encodec.reset_state(); + mimi.reset_state(); tracing::info!("processing loop"); let mut prev_text_token = config.text_start_token; let mut tensor_tokens = vec![]; @@ -407,7 +407,7 @@ impl StreamingModel { let sender = Arc::new(sender); let status = std::thread::scope(|s| { s.spawn({ - let mut encodec = encodec.clone(); + let mut mimi = mimi.clone(); let sender = sender.clone(); move || { 'outer: while let Ok(in_pcm) = receiver.recv() { @@ -421,7 +421,7 @@ impl StreamingModel { (1, 1, pcm_len), &candle::Device::Cpu, )?; - let audio_tokens = encodec.encode_step(&pcms.into())?; + let audio_tokens = mimi.encode_step(&pcms.into())?; let audio_tokens = match audio_tokens.as_option() { None => continue, Some(audio_tokens) => audio_tokens, @@ -438,7 +438,7 @@ impl StreamingModel { } }); s.spawn({ - let cb = app_state.config.encodec_num_codebooks; + let cb = app_state.config.mimi_num_codebooks; let sender = sender.clone(); move || { while let Ok(audio_tokens) = rx_o.recv() { @@ -450,7 +450,7 @@ impl StreamingModel { )? }; tensor_tokens.push(audio_tokens.clone()); - let pcm = encodec.decode_step(&audio_tokens.into())?; + let pcm = mimi.decode_step(&audio_tokens.into())?; if let Some(pcm) = pcm.as_option() { let pcm = pcm.i((0, 0))?.to_vec1::()?; sender.send(StreamOut::Pcm { pcm })?; @@ -516,7 +516,7 @@ impl StreamingModel { repetition_penalty, repetition_penalty_context, lm_model_file: self.state.config.lm_model_file.to_string(), - encodec_model_file: self.state.config.encodec_model_file.to_string(), + mimi_model_file: self.state.config.mimi_model_file.to_string(), build_info: crate::utils::BuildInfo::new(), instance_name: self.state.config.instance_name.to_string(), }; @@ -547,7 +547,7 @@ impl StreamingModel { ); // We want to log the output even if the run function returns an error. - let run_result = if self.state.config.use_cpu_for_encodec { + let run_result = if self.state.config.use_cpu_for_mimi { self.run_with_state_mt(&mut state, receiver, sender) } else { self.run_with_state(&mut state, receiver, sender) @@ -589,7 +589,7 @@ impl StreamingModel { last_step_idx: state.step_idx(), transcript, addr, - encodec_model_file: &self.state.config.encodec_model_file, + mimi_model_file: &self.state.config.mimi_model_file, lm_model_file: &self.state.config.lm_model_file, lm_config: &self.state.config.lm_config, })?; diff --git a/rust/moshi-core/src/lib.rs b/rust/moshi-core/src/lib.rs index 55744bd..9ec41ea 100644 --- a/rust/moshi-core/src/lib.rs +++ b/rust/moshi-core/src/lib.rs @@ -6,10 +6,10 @@ pub use candle; pub use candle_nn; pub mod conv; -pub mod encodec; pub mod lm; pub mod lm_generate; pub mod lm_generate_multistream; +pub mod mimi; pub mod quantization; pub mod quantized_lm; pub mod quantized_transformer; diff --git a/rust/moshi-core/src/encodec.rs b/rust/moshi-core/src/mimi.rs similarity index 98% rename from rust/moshi-core/src/encodec.rs rename to rust/moshi-core/src/mimi.rs index ba5d51e..45b74cf 100644 --- a/rust/moshi-core/src/encodec.rs +++ b/rust/moshi-core/src/mimi.rs @@ -90,7 +90,7 @@ impl Config { } #[derive(Debug, Clone)] -pub struct Encodec { +pub struct Mimi { encoder: seanet::SeaNetEncoder, decoder: seanet::SeaNetDecoder, encoder_transformer: transformer::ProjectedTransformer, @@ -101,7 +101,7 @@ pub struct Encodec { config: Config, } -impl Encodec { +impl Mimi { pub fn new(cfg: Config, vb: VarBuilder) -> Result { let dim = cfg.seanet.dimension; let encoder = seanet::SeaNetEncoder::new(&cfg.seanet, vb.pp("encoder"))?; @@ -221,10 +221,10 @@ impl Encodec { } } -pub fn load(model_file: &str, num_codebooks: Option, dev: &Device) -> Result { +pub fn load(model_file: &str, num_codebooks: Option, dev: &Device) -> Result { let vb = unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, dev)? }; let cfg = Config::v0_1(num_codebooks); - let encodec = Encodec::new(cfg, vb)?; - Ok(encodec) + let mimi = Mimi::new(cfg, vb)?; + Ok(mimi) } diff --git a/rust/moshi-core/src/tts.rs b/rust/moshi-core/src/tts.rs index 7ee9205..57fe381 100644 --- a/rust/moshi-core/src/tts.rs +++ b/rust/moshi-core/src/tts.rs @@ -9,7 +9,7 @@ use candle_transformers::models::t5; pub struct Config { pub t5: t5::Config, pub lm: crate::lm::Config, - pub encodec: crate::encodec::Config, + pub mimi: crate::mimi::Config, pub max_duration_s: f64, pub speaker_cond_duration_s: f64, pub max_speakers: usize, @@ -18,14 +18,14 @@ pub struct Config { impl Config { pub fn v0_1(t5: t5::Config) -> Self { let lm = crate::lm::Config::tts_v0_1(); - let encodec = crate::encodec::Config::v0_1(None); - Self { t5, lm, encodec, max_duration_s: 60., speaker_cond_duration_s: 4., max_speakers: 5 } + let mimi = crate::mimi::Config::v0_1(None); + Self { t5, lm, mimi, max_duration_s: 60., speaker_cond_duration_s: 4., max_speakers: 5 } } pub fn v0_2(t5: t5::Config) -> Self { let lm = crate::lm::Config::tts_v0_1(); - let encodec = crate::encodec::Config::v0_1(None); - Self { t5, lm, encodec, max_duration_s: 60., speaker_cond_duration_s: 10., max_speakers: 2 } + let mimi = crate::mimi::Config::v0_1(None); + Self { t5, lm, mimi, max_duration_s: 60., speaker_cond_duration_s: 10., max_speakers: 2 } } } @@ -33,7 +33,7 @@ impl Config { pub struct Model { t5: t5::T5EncoderModel, pub lm: crate::lm::Lm, - speaker_cond: Option<(crate::encodec::Encodec, Linear)>, + speaker_cond: Option<(crate::mimi::Mimi, Linear)>, t5_proj: Linear, pub sample_rate: f64, frame_rate: f64, @@ -55,13 +55,13 @@ impl Model { let speaker_cond = match vb_speaker_cond { None => None, Some(vb) => { - let encodec = crate::encodec::Encodec::new(cfg.encodec.clone(), vb)?; + let mimi = crate::mimi::Mimi::new(cfg.mimi.clone(), vb)?; let proj = linear_no_bias( - cfg.encodec.seanet.dimension, + cfg.mimi.seanet.dimension, cfg.lm.transformer.d_model, vb_lm.pp("condition_provider.conditioners.speaker_wavs.output_proj"), )?; - Some((encodec, proj)) + Some((mimi, proj)) } }; let t5_proj = { @@ -78,8 +78,8 @@ impl Model { lm, speaker_cond, t5_proj, - sample_rate: cfg.encodec.sample_rate, - frame_rate: cfg.encodec.frame_rate, + sample_rate: cfg.mimi.sample_rate, + frame_rate: cfg.mimi.frame_rate, audio_vocab_size: cfg.lm.audio_vocab_size as u32, audio_codebooks: cfg.lm.audio_codebooks, max_duration_s: cfg.max_duration_s, @@ -118,7 +118,7 @@ impl Model { Some(speaker_pcm) => { let sc = match self.speaker_cond.as_mut() { None => candle::bail!("speaker_pcm specified without a speaker-cond model"), - Some((encodec, proj)) => encodec + Some((mimi, proj)) => mimi .encode_pre_quantize(speaker_pcm)? .t()? .to_dtype(candle::DType::BF16)?