diff --git a/rust/mimi-pyo3/src/lib.rs b/rust/mimi-pyo3/src/lib.rs
index ee696aa..64e9f3f 100644
--- a/rust/mimi-pyo3/src/lib.rs
+++ b/rust/mimi-pyo3/src/lib.rs
@@ -81,7 +81,7 @@ fn encodec_cfg(max_seq_len: Option<usize>) -> encodec::Config {
         dim_feedforward: 2048,
         kv_repeat: 1,
         conv_layout: true, // see builders.py
-        cross_attention: false,
+        cross_attention: None,
         max_seq_len: max_seq_len.unwrap_or(8192), // the transformer works at 25hz so this is ~5 mins.
     };
     encodec::Config {
diff --git a/rust/moshi-core/src/encodec.rs b/rust/moshi-core/src/encodec.rs
index e889625..ba5d51e 100644
--- a/rust/moshi-core/src/encodec.rs
+++ b/rust/moshi-core/src/encodec.rs
@@ -63,7 +63,7 @@ impl Config {
             conv_kernel_size: 5,
             use_conv_bias: true,
             use_conv_block: false,
-            cross_attention: false,
+            cross_attention: None,
             max_period: 10000,
             gating: None,
             norm: crate::NormType::LayerNorm,
diff --git a/rust/moshi-core/src/lm.rs b/rust/moshi-core/src/lm.rs
index 4cd0be5..5fd286d 100644
--- a/rust/moshi-core/src/lm.rs
+++ b/rust/moshi-core/src/lm.rs
@@ -53,7 +53,7 @@ impl Config {
             max_period: 10000,
             use_conv_block: false,
             use_conv_bias: true,
-            cross_attention: false,
+            cross_attention: None,
             gating: Some(candle_nn::Activation::Silu),
             norm: crate::NormType::RmsNorm,
             positional_embedding: transformer::PositionalEmbedding::Rope,
@@ -76,7 +76,7 @@ impl Config {
             max_period: 10000,
             use_conv_block: false,
             use_conv_bias: true,
-            cross_attention: false,
+            cross_attention: None,
             gating: Some(candle_nn::Activation::Silu),
             norm: crate::NormType::RmsNorm,
             positional_embedding: transformer::PositionalEmbedding::None,
@@ -122,7 +122,7 @@ impl Config {
             max_period: 10000,
             use_conv_block: false,
             use_conv_bias: true,
-            cross_attention: true,
+            cross_attention: None,
             gating: None,
             norm: crate::NormType::LayerNorm,
             positional_embedding: transformer::PositionalEmbedding::Rope,
@@ -145,7 +145,7 @@ impl Config {
             max_period: 10000,
             use_conv_block: false,
             use_conv_bias: true,
-            cross_attention: false,
+            cross_attention: None,
             gating: None,
             norm: crate::NormType::LayerNorm,
             positional_embedding: transformer::PositionalEmbedding::Sin,
diff --git a/rust/moshi-core/src/quantized_transformer.rs b/rust/moshi-core/src/quantized_transformer.rs
index c5db7b5..b35f1e5 100644
--- a/rust/moshi-core/src/quantized_transformer.rs
+++ b/rust/moshi-core/src/quantized_transformer.rs
@@ -3,7 +3,7 @@
 // LICENSE file in the root directory of this source tree.
 
 use crate::streaming::{StreamTensor, StreamingModule};
-use crate::transformer::{get_mask, PositionalEmbedding, RotaryEmbedding};
+use crate::transformer::{get_mask, CrossAttention, PositionalEmbedding, RotaryEmbedding};
 use candle::{DType, IndexOp, Module, Result, Tensor, D};
 use candle_transformers::quantized_nn::{layer_norm, linear_b, Linear};
 
@@ -169,7 +169,7 @@ pub struct StreamingMultiheadCrossAttention {
 }
 
 impl StreamingMultiheadCrossAttention {
-    pub fn new(_cfg: &Config, _vb: VarBuilder) -> Result<Self> {
+    pub fn new(_ca: CrossAttention, _cfg: &Config, _vb: VarBuilder) -> Result<Self> {
         candle::bail!("cross-attn is not supported at the moment")
     }
 
@@ -347,12 +347,14 @@ impl StreamingTransformerLayer {
             }
         };
         let self_attn = StreamingMultiheadAttention::new(rope, cfg, vb.pp("self_attn"))?;
-        let cross_attn = if cfg.cross_attention {
-            let norm_cross = layer_norm(cfg.d_model, 1e-5, vb.pp("norm_cross"))?;
-            let cross_attn = StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"))?;
-            Some((norm_cross, cross_attn))
-        } else {
-            None
+        let cross_attn = match cfg.cross_attention {
+            Some(ca) => {
+                let norm_cross = layer_norm(cfg.d_model, 1e-5, vb.pp("norm_cross"))?;
+                let cross_attn =
+                    StreamingMultiheadCrossAttention::new(ca, cfg, vb.pp("cross_attention"))?;
+                Some((norm_cross, cross_attn))
+            }
+            None => None,
         };
         Ok(Self {
             self_attn,
diff --git a/rust/moshi-core/src/transformer.rs b/rust/moshi-core/src/transformer.rs
index a2b7f33..e4c4b2b 100644
--- a/rust/moshi-core/src/transformer.rs
+++ b/rust/moshi-core/src/transformer.rs
@@ -22,6 +22,12 @@ pub enum PositionalEmbedding {
     None,
 }
 
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub enum CrossAttention {
+    Normal,
+    Gated,
+}
+
 #[derive(Debug, Clone)]
 pub struct Config {
     pub d_model: usize,
@@ -34,7 +40,7 @@ pub struct Config {
     pub layer_scale: Option<f64>,
     pub positional_embedding: PositionalEmbedding,
     pub use_conv_block: bool,
-    pub cross_attention: bool,
+    pub cross_attention: Option<CrossAttention>,
     pub conv_kernel_size: usize,
     pub use_conv_bias: bool,
     pub gating: Option<candle_nn::Activation>,
@@ -242,11 +248,12 @@ pub struct StreamingMultiheadCrossAttention {
     kv_repeat: usize,
     num_heads: usize,
     neg_inf: Tensor,
+    tanh_gate_alpha: Option<Tensor>,
     span: tracing::Span,
 }
 
 impl StreamingMultiheadCrossAttention {
-    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+    pub fn new(ca: CrossAttention, cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let embed_dim = cfg.d_model;
         let num_kv = cfg.num_heads / cfg.kv_repeat;
         let kv_dim = num_kv * (embed_dim / cfg.num_heads);
@@ -269,6 +276,10 @@ impl StreamingMultiheadCrossAttention {
         let in_proj_v = Linear::new(in_proj_weight_v, in_proj_bias_v);
         let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("out_proj"))?;
         let neg_inf = Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?;
+        let tanh_gate_alpha = match ca {
+            CrossAttention::Gated => Some(vb.get((1, 1, 1), "tanh_gate.alpha")?.tanh()?),
+            CrossAttention::Normal => None,
+        };
         Ok(Self {
             in_proj_q,
             in_proj_k,
@@ -277,6 +288,7 @@ impl StreamingMultiheadCrossAttention {
             kv_repeat: cfg.kv_repeat,
             num_heads: cfg.num_heads,
             neg_inf,
+            tanh_gate_alpha,
             span: tracing::span!(tracing::Level::TRACE, "mhca"),
         })
     }
@@ -320,39 +332,28 @@ impl StreamingMultiheadCrossAttention {
             .transpose(1, 2)? // b,t,h,d
             .reshape((b, t, hd))?
             .apply(&self.out_proj)?;
+        let xs = match self.tanh_gate_alpha.as_ref() {
+            None => xs,
+            Some(alpha) => xs.broadcast_mul(alpha)?,
+        };
         Ok(xs)
     }
 }
 
 #[derive(Debug, Clone)]
 pub enum Mlp {
-    NoGating {
-        span1: tracing::Span,
-        linear1: Linear,
-        span2: tracing::Span,
-        linear2: Linear,
-        span: tracing::Span,
-    },
-    Gating {
-        linear_in: Linear,
-        linear_out: Linear,
-        activation: candle_nn::Activation,
-        span: tracing::Span,
-    },
+    NoGating { linear1: Linear, linear2: Linear },
+    Gating { linear_in: Linear, linear_out: Linear, activation: candle_nn::Activation },
 }
 
 impl Mlp {
     pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let d_model = cfg.d_model;
-        let span = tracing::span!(tracing::Level::TRACE, "mlp");
-
         match cfg.gating {
             None => {
-                let span1 = tracing::span!(tracing::Level::TRACE, "lin1");
-                let span2 = tracing::span!(tracing::Level::TRACE, "lin2");
                 let linear1 = linear(d_model, cfg.dim_feedforward, cfg.bias_ff, vb.pp("linear1"))?;
                 let linear2 = linear(cfg.dim_feedforward, d_model, cfg.bias_ff, vb.pp("linear2"))?;
-                Ok(Self::NoGating { linear1, linear2, span, span1, span2 })
+                Ok(Self::NoGating { linear1, linear2 })
             }
             Some(activation) => {
                 let vb = vb.pp("gating");
@@ -364,7 +365,7 @@ impl Mlp {
                 // TODO: Maybe use bias_ff here?
                 let linear_in = linear(d_model, 2 * hidden, false, vb.pp("linear_in"))?;
                 let linear_out = linear(hidden, d_model, false, vb.pp("linear_out"))?;
-                Ok(Self::Gating { linear_in, linear_out, activation, span })
+                Ok(Self::Gating { linear_in, linear_out, activation })
             }
         }
     }
@@ -373,20 +374,8 @@ impl Mlp {
 impl Module for Mlp {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         match self {
-            Self::NoGating { linear1, linear2, span, span1, span2 } => {
-                let _enter = span.enter();
-                let xs = {
-                    let _enter = span1.enter();
-                    xs.apply(linear1)?
-                };
-                let xs = xs.gelu_erf()?;
-                {
-                    let _enter = span2.enter();
-                    xs.apply(linear2)
-                }
-            }
-            Self::Gating { linear_in, linear_out, activation, span } => {
-                let _enter = span.enter();
+            Self::NoGating { linear1, linear2 } => xs.apply(linear1)?.gelu_erf()?.apply(linear2),
+            Self::Gating { linear_in, linear_out, activation } => {
                 let xs = xs.apply(linear_in)?;
                 let (b, t, _) = xs.dims3()?;
                 let xs = xs.reshape((b, t, 2, ()))?;
@@ -416,9 +405,43 @@ impl Module for RmsNorm {
     }
 }
 
+#[derive(Debug, Clone)]
+pub struct LayerNorm {
+    inner: candle_nn::LayerNorm,
+}
+
+impl LayerNorm {
+    pub fn new(d_model: usize, eps: f32, vb: VarBuilder) -> Result<Self> {
+        let bias = vb.get(d_model, "bias")?;
+        let alpha = if vb.contains_tensor("alpha") {
+            vb.get((1, 1, d_model), "alpha")?.reshape(d_model)?
+        } else {
+            vb.get(d_model, "weight")?.reshape(d_model)?
+        };
+        let inner = candle_nn::LayerNorm::new(alpha, bias, eps as f64);
+        Ok(Self { inner })
+    }
+
+    pub fn new_no_bias(d_model: usize, eps: f32, vb: VarBuilder) -> Result<Self> {
+        let alpha = if vb.contains_tensor("alpha") {
+            vb.get((1, 1, d_model), "alpha")?.reshape(d_model)?
+        } else {
+            vb.get(d_model, "weight")?.reshape(d_model)?
+        };
+        let inner = candle_nn::LayerNorm::new_no_bias(alpha, eps as f64);
+        Ok(Self { inner })
+    }
+}
+
+impl Module for LayerNorm {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        self.inner.forward(xs)
+    }
+}
+
 #[derive(Debug, Clone)]
 pub enum Norm {
-    LayerNorm(candle_nn::LayerNorm),
+    LayerNorm(LayerNorm),
     RmsNorm(RmsNorm),
 }
 
@@ -426,7 +449,7 @@ impl Norm {
     pub fn new(d_model: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let norm = match cfg.norm {
             crate::NormType::LayerNorm => {
-                let norm = candle_nn::layer_norm(d_model, 1e-5, vb)?;
+                let norm = LayerNorm::new(d_model, 1e-5, vb)?;
                 Self::LayerNorm(norm)
             }
             crate::NormType::RmsNorm => {
@@ -455,7 +478,7 @@ pub struct StreamingTransformerLayer {
     norm2: Norm,
     layer_scale_1: Option<LayerScale>,
    layer_scale_2: Option<LayerScale>,
-    cross_attn: Option<(candle_nn::LayerNorm, StreamingMultiheadCrossAttention)>,
+    cross_attn: Option<(LayerNorm, StreamingMultiheadCrossAttention)>,
     norm_first: bool,
     span: tracing::Span,
 }
@@ -469,8 +492,8 @@ impl StreamingTransformerLayer {
         let mlp = Mlp::new(cfg, vb.clone())?;
         let (norm1, norm2) = match cfg.norm {
             crate::NormType::LayerNorm => {
-                let norm1 = candle_nn::layer_norm(d_model, 1e-5, vb.pp("norm1"))?;
-                let norm2 = candle_nn::layer_norm(d_model, 1e-5, vb.pp("norm2"))?;
+                let norm1 = LayerNorm::new(d_model, 1e-5, vb.pp("norm1"))?;
+                let norm2 = LayerNorm::new(d_model, 1e-5, vb.pp("norm2"))?;
                 (Norm::LayerNorm(norm1), Norm::LayerNorm(norm2))
             }
             crate::NormType::RmsNorm => {
@@ -494,12 +517,14 @@ impl StreamingTransformerLayer {
             }
         };
         let self_attn = StreamingMultiheadAttention::new(rope, cfg, vb.pp("self_attn"))?;
-        let cross_attn = if cfg.cross_attention {
-            let norm_cross = candle_nn::layer_norm(cfg.d_model, 1e-5, vb.pp("norm_cross"))?;
-            let cross_attn = StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"))?;
-            Some((norm_cross, cross_attn))
-        } else {
-            None
+        let cross_attn = match cfg.cross_attention {
+            Some(ca) => {
+                let norm_cross = LayerNorm::new_no_bias(cfg.d_model, 1e-5, vb.pp("norm_cross"))?;
+                let cross_attn =
+                    StreamingMultiheadCrossAttention::new(ca, cfg, vb.pp("cross_attention"))?;
+                Some((norm_cross, cross_attn))
+            }
+            None => None,
         };
         Ok(Self {
             self_attn,
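
Note (illustrative, not part of the patch): with this change, cross-attention is configured through Option<CrossAttention> instead of a bool, so a config can pick the plain variant, the tanh-gated variant, or disable cross-attention entirely. A minimal sketch, assuming an existing transformer::Config value named base_cfg; the helper name and the import path are assumptions, only the cross_attention field and the CrossAttention enum come from this patch:

    use moshi::transformer::{Config, CrossAttention};

    // Rebuilds the config with the gated variant enabled.
    // `Some(CrossAttention::Gated)` makes StreamingMultiheadCrossAttention load the extra
    // `tanh_gate.alpha` tensor and scale its output by tanh(alpha);
    // `Some(CrossAttention::Normal)` skips the gate, and `None` (used by the v0_1-style
    // configs above) builds the layer without any cross-attention block.
    fn with_gated_cross_attention(base_cfg: Config) -> Config {
        Config { cross_attention: Some(CrossAttention::Gated), ..base_cfg }
    }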
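
Note (illustrative, not part of the patch): the new LayerNorm wrapper accepts the scale under either of two tensor names, a broadcastable `alpha` of shape (1, 1, d_model) or a conventional `weight` vector of shape (d_model,), and reshapes both to (d_model,) before building candle_nn::LayerNorm, presumably so that both checkpoint layouts load. A small sketch of just that lookup, using only VarBuilder calls that appear in the patch (the function name is hypothetical):

    // Fetch the layer-norm scale, trying the `alpha` name first and falling back to
    // the conventional `weight` name, mirroring LayerNorm::new / new_no_bias above.
    fn norm_scale(d_model: usize, vb: candle_nn::VarBuilder) -> candle::Result<candle::Tensor> {
        if vb.contains_tensor("alpha") {
            vb.get((1, 1, d_model), "alpha")?.reshape(d_model)
        } else {
            vb.get(d_model, "weight")?.reshape(d_model)
        }
    }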