Skip to content

Commit

Permalink
Merge pull request #167 from kyutai-labs/encodec-rename
Browse files Browse the repository at this point in the history
Rename encodec to mimi.
  • Loading branch information
LaurentMazare authored Dec 10, 2024
2 parents eee03a6 + aa98387 commit 295758d
Show file tree
Hide file tree
Showing 12 changed files with 100 additions and 102 deletions.
10 changes: 5 additions & 5 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ members = [
resolver = "2"

[workspace.package]
version = "0.2.4"
version = "0.3.0"
edition = "2021"
license = "MIT/Apache-2.0"
description = "moshi, a real-time voice AI"
Expand All @@ -18,10 +18,10 @@ categories = ["science"]


[workspace.dependencies]
candle = { version = "0.7.2", package = "candle-core" }
candle-nn = "0.7.2"
candle-transformers = "0.7.2"
candle-flash-attn = "0.7.2"
candle = { version = "0.8.1", package = "candle-core" }
candle-nn = "0.8.1"
candle-transformers = "0.8.1"
candle-flash-attn = "0.8.1"

[profile.release]
debug = true
Expand Down
2 changes: 1 addition & 1 deletion rust/mimi-pyo3/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ crate-type = ["cdylib"]
anyhow = "1"
numpy = "0.21.0"
pyo3 = "0.21.0"
moshi = { path = "../moshi-core", version = "0.2.4" }
moshi = { path = "../moshi-core", version = "0.3.0" }
36 changes: 18 additions & 18 deletions rust/mimi-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
use pyo3::prelude::*;

use ::moshi as mm;
use mm::{candle, candle_nn, conv, encodec, seanet, transformer};
use mm::{candle, candle_nn, conv, mimi, seanet, transformer};

trait PyRes<R> {
#[allow(unused)]
Expand Down Expand Up @@ -39,7 +39,7 @@ macro_rules! py_bail {
};
}

fn encodec_cfg(max_seq_len: Option<usize>) -> encodec::Config {
fn mimi_cfg(max_seq_len: Option<usize>) -> mimi::Config {
let seanet_cfg = seanet::Config {
dimension: 512,
channels: 1,
Expand Down Expand Up @@ -84,12 +84,12 @@ fn encodec_cfg(max_seq_len: Option<usize>) -> encodec::Config {
cross_attention: None,
max_seq_len: max_seq_len.unwrap_or(8192), // the transformer works at 25hz so this is ~5 mins.
};
encodec::Config {
mimi::Config {
channels: 1,
sample_rate: 24_000.,
frame_rate: 12.5,
renormalize: true,
resample_method: encodec::ResampleMethod::Conv,
resample_method: mimi::ResampleMethod::Conv,
seanet: seanet_cfg,
transformer: transformer_cfg,
quantizer_n_q: 8,
Expand All @@ -100,7 +100,7 @@ fn encodec_cfg(max_seq_len: Option<usize>) -> encodec::Config {

#[pyclass]
struct Tokenizer {
encodec: encodec::Encodec,
mimi: mimi::Mimi,
device: candle::Device,
dtype: candle::DType,
}
Expand All @@ -119,9 +119,9 @@ impl Tokenizer {
};
let vb =
unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[path], dtype, &device).w()? };
let cfg = encodec_cfg(max_seq_len);
let encodec = encodec::Encodec::new(cfg, vb).w()?;
Ok(Self { encodec, device, dtype })
let cfg = mimi_cfg(max_seq_len);
let mimi = mimi::Mimi::new(cfg, vb).w()?;
Ok(Self { mimi, device, dtype })
}

fn encode(&mut self, pcm_data: numpy::PyReadonlyArray3<f32>) -> PyResult<PyObject> {
Expand All @@ -136,7 +136,7 @@ impl Tokenizer {
.allow_threads(|| {
let pcm_data = candle::Tensor::from_slice(pcm_data, pcm_shape, &self.device)?
.to_dtype(self.dtype)?;
let codes = self.encodec.encode(&pcm_data)?;
let codes = self.mimi.encode(&pcm_data)?;
codes.to_vec3::<u32>()
})
.w()?;
Expand All @@ -156,7 +156,7 @@ impl Tokenizer {
.allow_threads(|| {
let pcm_data = candle::Tensor::from_slice(pcm_data, pcm_shape, &self.device)?
.to_dtype(self.dtype)?;
let codes = self.encodec.encode_step(&pcm_data.into())?;
let codes = self.mimi.encode_step(&pcm_data.into())?;
match codes.as_option() {
Some(codes) => Ok::<_, candle::Error>(Some(codes.to_vec3::<u32>()?)),
None => Ok(None),
Expand All @@ -182,7 +182,7 @@ impl Tokenizer {
let pcm = py
.allow_threads(|| {
let codes = candle::Tensor::from_slice(codes, codes_shape, &self.device)?;
let pcm = self.encodec.decode(&codes)?.to_dtype(candle::DType::F32)?;
let pcm = self.mimi.decode(&codes)?.to_dtype(candle::DType::F32)?;
pcm.to_vec3::<f32>()
})
.w()?;
Expand All @@ -204,7 +204,7 @@ impl Tokenizer {
let pcm = py
.allow_threads(|| {
let codes = candle::Tensor::from_slice(codes, codes_shape, &self.device)?;
let pcm = self.encodec.decode_step(&codes.into())?;
let pcm = self.mimi.decode_step(&codes.into())?;
match pcm.as_option() {
Some(pcm) => {
let pcm = pcm.to_dtype(candle::DType::F32)?;
Expand All @@ -224,7 +224,7 @@ impl Tokenizer {
}

fn reset(&mut self) {
self.encodec.reset_state()
self.mimi.reset_state()
}
}

Expand Down Expand Up @@ -252,9 +252,9 @@ impl StreamTokenizer {
};
let vb =
unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[path], dtype, &device).w()? };
let cfg = encodec_cfg(max_seq_len);
let mut e_encodec = encodec::Encodec::new(cfg, vb).w()?;
let mut d_encodec = e_encodec.clone();
let cfg = mimi_cfg(max_seq_len);
let mut e_mimi = mimi::Mimi::new(cfg, vb).w()?;
let mut d_mimi = e_mimi.clone();
let (encoder_tx, e_rx) = std::sync::mpsc::channel::<Vec<f32>>();
let (decoder_tx, d_rx) = std::sync::mpsc::channel::<Vec<Vec<u32>>>();
let (d_tx, decoder_rx) = std::sync::mpsc::channel::<Vec<f32>>();
Expand All @@ -267,7 +267,7 @@ impl StreamTokenizer {
let pcm_data =
candle::Tensor::from_vec(pcm_data, (1, 1, l), &candle::Device::Cpu)?
.to_dtype(dtype)?;
let codes = e_encodec.encode_step(&pcm_data.into())?;
let codes = e_mimi.encode_step(&pcm_data.into())?;
if let Some(codes) = codes.as_option() {
let mut codes = codes.to_vec3::<u32>()?;
e_tx.send(codes.remove(0))?;
Expand All @@ -282,7 +282,7 @@ impl StreamTokenizer {
while let Ok(codes) = d_rx.recv() {
if let Err(err) = (|| {
let codes = candle::Tensor::new(codes, &candle::Device::Cpu)?.unsqueeze(2)?;
let pcm_data = d_encodec.decode_step(&codes.into())?;
let pcm_data = d_mimi.decode_step(&codes.into())?;
if let Some(pcm_data) = pcm_data.as_option() {
let mut pcm_data = pcm_data.to_vec3::<f32>()?;
d_tx.send(pcm_data.remove(0).remove(0))?;
Expand Down
2 changes: 1 addition & 1 deletion rust/moshi-backend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ rcgen = "0.13.1"
http = "1.1.0"
lazy_static = "1.5.0"
log = "0.4.20"
moshi = { path = "../moshi-core", version = "0.2.4" }
moshi = { path = "../moshi-core", version = "0.3.0" }
ogg = { version = "0.9.1", features = ["async"] }
opus = "0.3.0"
rand = { version = "0.8.5", features = ["getrandom"] }
Expand Down
4 changes: 2 additions & 2 deletions rust/moshi-backend/config-q8.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
"lm_model_file": "$HOME/tmp/moshiko_rs_301e30bf@120/model.q8.gguf",
"text_tokenizer_file": "$HOME/tmp/tokenizer_spm_32k_3.model",
"log_dir": "$HOME/tmp/moshi-logs",
"encodec_model_file": "$HOME/tmp/tokenizer-e351c8d8-checkpoint125.safetensors",
"encodec_num_codebooks": 8,
"mimi_model_file": "$HOME/tmp/tokenizer-e351c8d8-checkpoint125.safetensors",
"mimi_num_codebooks": 8,
"static_dir": "../client/dist",
"addr": "0.0.0.0",
"port": 8998,
Expand Down
4 changes: 2 additions & 2 deletions rust/moshi-backend/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
"lm_model_file": "$HOME/tmp/moshiko_rs_301e30bf@120/model.safetensors",
"text_tokenizer_file": "$HOME/tmp/tokenizer_spm_32k_3.model",
"log_dir": "$HOME/tmp/moshi-logs",
"encodec_model_file": "$HOME/tmp/tokenizer-e351c8d8-checkpoint125.safetensors",
"encodec_num_codebooks": 8,
"mimi_model_file": "$HOME/tmp/tokenizer-e351c8d8-checkpoint125.safetensors",
"mimi_num_codebooks": 8,
"static_dir": "../client/dist",
"addr": "0.0.0.0",
"port": 8998,
Expand Down
21 changes: 10 additions & 11 deletions rust/moshi-backend/src/benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,22 +79,21 @@ pub async fn run(args: &crate::BenchmarkArgs, config: &Config) -> Result<()> {
};
if args.mimi_only {
let device = crate::standalone::device(args.cpu)?;
let encodec_device =
if config.use_cpu_for_encodec { &candle::Device::Cpu } else { &device };
let mut encodec_model = moshi::encodec::load(
&config.encodec_model_file,
Some(config.encodec_num_codebooks),
encodec_device,
let mimi_device = if config.use_cpu_for_mimi { &candle::Device::Cpu } else { &device };
let mut mimi_model = moshi::mimi::load(
&config.mimi_model_file,
Some(config.mimi_num_codebooks),
mimi_device,
)?;
let config = encodec_model.config();
let config = mimi_model.config();
let frame_length = (config.sample_rate / config.frame_rate).ceil() as usize;
for _step in 0..args.steps {
let fake_pcm =
candle::Tensor::zeros((1, 1, frame_length), candle::DType::F32, encodec_device)?;
let codes = encodec_model.encode_step(&fake_pcm.into())?;
let ys = encodec_model.decode_step(&codes)?;
candle::Tensor::zeros((1, 1, frame_length), candle::DType::F32, mimi_device)?;
let codes = mimi_model.encode_step(&fake_pcm.into())?;
let ys = mimi_model.decode_step(&codes)?;
if ys.as_option().is_none() {
anyhow::bail!("Expected Encodec to output some stuff, but nothing came out.");
anyhow::bail!("Expected mimi to output some stuff, but nothing came out.");
}
device.synchronize()?;
}
Expand Down
33 changes: 16 additions & 17 deletions rust/moshi-backend/src/standalone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ impl Config {
config.stream.log_dir = crate::utils::replace_env_vars(&config.stream.log_dir);
config.stream.text_tokenizer_file =
crate::utils::replace_env_vars(&config.stream.text_tokenizer_file);
config.stream.encodec_model_file =
crate::utils::replace_env_vars(&config.stream.encodec_model_file);
config.stream.mimi_model_file =
crate::utils::replace_env_vars(&config.stream.mimi_model_file);
config.stream.lm_model_file = crate::utils::replace_env_vars(&config.stream.lm_model_file);
Ok(config)
}
Expand Down Expand Up @@ -59,36 +59,35 @@ impl stream_both::AppStateInner {
let device = device(args.cpu)?;
let dtype = if device.is_cuda() { candle::DType::BF16 } else { candle::DType::F32 };
let lm_model = moshi::lm::load_streaming(&config.lm_model_file, dtype, &device)?;
let encodec_device =
if config.use_cpu_for_encodec { &candle::Device::Cpu } else { &device };
let encodec_model = moshi::encodec::load(
&config.encodec_model_file,
Some(config.encodec_num_codebooks),
encodec_device,
let mimi_device = if config.use_cpu_for_mimi { &candle::Device::Cpu } else { &device };
let mimi_model = moshi::mimi::load(
&config.mimi_model_file,
Some(config.mimi_num_codebooks),
mimi_device,
)?;
let text_tokenizer =
sentencepiece::SentencePieceProcessor::open(&config.text_tokenizer_file)?;
// Warm-up code.
{
tracing::info!(?dtype, ?device, "warming up the model");
let mut lm_model = lm_model.clone();
let (_v, ys) = lm_model.forward(None, vec![None; config.encodec_num_codebooks])?;
let (_v, ys) = lm_model.forward(None, vec![None; config.mimi_num_codebooks])?;
let mut lp = candle_transformers::generation::LogitsProcessor::new(123, None, None);
let _ = lm_model.depformer_sample(0, &ys, None, &mut lp)?;
let mut encodec_model = encodec_model.clone();
let config = encodec_model.config();
let mut mimi_model = mimi_model.clone();
let config = mimi_model.config();
let frame_length = (config.sample_rate / config.frame_rate).ceil() as usize;
let fake_pcm =
candle::Tensor::zeros((1, 1, frame_length), candle::DType::F32, encodec_device)?;
let codes = encodec_model.encode_step(&fake_pcm.into())?;
let ys = encodec_model.decode_step(&codes)?;
candle::Tensor::zeros((1, 1, frame_length), candle::DType::F32, mimi_device)?;
let codes = mimi_model.encode_step(&fake_pcm.into())?;
let ys = mimi_model.decode_step(&codes)?;
if ys.as_option().is_none() {
anyhow::bail!("Expected Encodec to output some stuff, but nothing came out.");
anyhow::bail!("Expected mimi to output some stuff, but nothing came out.");
}
device.synchronize()?;
tracing::info!("model is ready to roll!");
}
Ok(Self { lm_model, encodec_model, device, config: config.clone(), text_tokenizer })
Ok(Self { lm_model, mimi_model, device, config: config.clone(), text_tokenizer })
}
}

Expand Down Expand Up @@ -121,7 +120,7 @@ pub async fn download_from_hub(config: &mut stream_both::Config) -> Result<()> {
.ok_or_else(|| anyhow::anyhow!("'{path}' has no file name"))
};
for file_path in
[&mut config.lm_model_file, &mut config.encodec_model_file, &mut config.text_tokenizer_file]
[&mut config.lm_model_file, &mut config.mimi_model_file, &mut config.text_tokenizer_file]
.iter_mut()
{
let filename = extract_filename(file_path)
Expand Down
Loading

0 comments on commit 295758d

Please sign in to comment.