
Merge pull request #127 from kyutai-labs/transformer-update
Transformer update.
LaurentMazare authored Sep 30, 2024
2 parents 0882b66 + 4e91641 commit 1700456
Showing 5 changed files with 87 additions and 60 deletions.
2 changes: 1 addition & 1 deletion rust/mimi-pyo3/src/lib.rs
@@ -81,7 +81,7 @@ fn encodec_cfg(max_seq_len: Option<usize>) -> encodec::Config {
dim_feedforward: 2048,
kv_repeat: 1,
conv_layout: true, // see builders.py
cross_attention: false,
cross_attention: None,
max_seq_len: max_seq_len.unwrap_or(8192), // the transformer works at 25hz so this is ~5 mins.
};
encodec::Config {
2 changes: 1 addition & 1 deletion rust/moshi-core/src/encodec.rs
@@ -63,7 +63,7 @@ impl Config {
conv_kernel_size: 5,
use_conv_bias: true,
use_conv_block: false,
cross_attention: false,
cross_attention: None,
max_period: 10000,
gating: None,
norm: crate::NormType::LayerNorm,
8 changes: 4 additions & 4 deletions rust/moshi-core/src/lm.rs
@@ -53,7 +53,7 @@ impl Config {
max_period: 10000,
use_conv_block: false,
use_conv_bias: true,
cross_attention: false,
cross_attention: None,
gating: Some(candle_nn::Activation::Silu),
norm: crate::NormType::RmsNorm,
positional_embedding: transformer::PositionalEmbedding::Rope,
@@ -76,7 +76,7 @@ impl Config {
max_period: 10000,
use_conv_block: false,
use_conv_bias: true,
cross_attention: false,
cross_attention: None,
gating: Some(candle_nn::Activation::Silu),
norm: crate::NormType::RmsNorm,
positional_embedding: transformer::PositionalEmbedding::None,
@@ -122,7 +122,7 @@ impl Config {
max_period: 10000,
use_conv_block: false,
use_conv_bias: true,
cross_attention: true,
cross_attention: None,
gating: None,
norm: crate::NormType::LayerNorm,
positional_embedding: transformer::PositionalEmbedding::Rope,
@@ -145,7 +145,7 @@ impl Config {
max_period: 10000,
use_conv_block: false,
use_conv_bias: true,
cross_attention: false,
cross_attention: None,
gating: None,
norm: crate::NormType::LayerNorm,
positional_embedding: transformer::PositionalEmbedding::Sin,
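
The lm.rs config changes above replace the boolean cross_attention flag with an Option<CrossAttention> (the enum is added in transformer.rs further down). As a rough sketch of how downstream code might bridge an old boolean flag onto the new field, assuming the crate-internal path crate::transformer::CrossAttention from this diff; the helper name is hypothetical and not part of the commit:

use crate::transformer::CrossAttention;

// Hypothetical migration helper, not part of this diff: the old boolean had no
// notion of gating, so `true` maps to the plain variant and `false` to None.
fn cross_attention_from_flag(enabled: bool) -> Option<CrossAttention> {
    if enabled {
        Some(CrossAttention::Normal)
    } else {
        None
    }
}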
18 changes: 10 additions & 8 deletions rust/moshi-core/src/quantized_transformer.rs
@@ -3,7 +3,7 @@
// LICENSE file in the root directory of this source tree.

use crate::streaming::{StreamTensor, StreamingModule};
use crate::transformer::{get_mask, PositionalEmbedding, RotaryEmbedding};
use crate::transformer::{get_mask, CrossAttention, PositionalEmbedding, RotaryEmbedding};

use candle::{DType, IndexOp, Module, Result, Tensor, D};
use candle_transformers::quantized_nn::{layer_norm, linear_b, Linear};
@@ -169,7 +169,7 @@ pub struct StreamingMultiheadCrossAttention {
}

impl StreamingMultiheadCrossAttention {
pub fn new(_cfg: &Config, _vb: VarBuilder) -> Result<Self> {
pub fn new(_ca: CrossAttention, _cfg: &Config, _vb: VarBuilder) -> Result<Self> {
candle::bail!("cross-attn is not supported at the moment")
}

@@ -347,12 +347,14 @@ impl StreamingTransformerLayer {
}
};
let self_attn = StreamingMultiheadAttention::new(rope, cfg, vb.pp("self_attn"))?;
let cross_attn = if cfg.cross_attention {
let norm_cross = layer_norm(cfg.d_model, 1e-5, vb.pp("norm_cross"))?;
let cross_attn = StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"))?;
Some((norm_cross, cross_attn))
} else {
None
let cross_attn = match cfg.cross_attention {
Some(ca) => {
let norm_cross = layer_norm(cfg.d_model, 1e-5, vb.pp("norm_cross"))?;
let cross_attn =
StreamingMultiheadCrossAttention::new(ca, cfg, vb.pp("cross_attention"))?;
Some((norm_cross, cross_attn))
}
None => None,
};
Ok(Self {
self_attn,
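
Note that even with the added CrossAttention parameter, the quantized cross-attention constructor above still bails ("cross-attn is not supported at the moment"), so configs with cross_attention set to Some(..) cannot be used on the quantized path. A minimal sketch of an up-front guard a caller might add, assuming crate::transformer::Config as used in this diff; the function name is hypothetical:

use crate::transformer::Config;

// Hypothetical guard, not part of this diff: reject cross-attention configs
// before building quantized layers, since the quantized constructor errors later.
fn check_quantized_cross_attention(cfg: &Config) -> candle::Result<()> {
    if cfg.cross_attention.is_some() {
        candle::bail!("cross-attn is not supported by the quantized transformer")
    }
    Ok(())
}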
117 changes: 71 additions & 46 deletions rust/moshi-core/src/transformer.rs
@@ -22,6 +22,12 @@ pub enum PositionalEmbedding {
None,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum CrossAttention {
Normal,
Gated,
}

#[derive(Debug, Clone)]
pub struct Config {
pub d_model: usize,
@@ -34,7 +40,7 @@ pub struct Config {
pub layer_scale: Option<f64>,
pub positional_embedding: PositionalEmbedding,
pub use_conv_block: bool,
pub cross_attention: bool,
pub cross_attention: Option<CrossAttention>,
pub conv_kernel_size: usize,
pub use_conv_bias: bool,
pub gating: Option<candle_nn::Activation>,
@@ -242,11 +248,12 @@ pub struct StreamingMultiheadCrossAttention {
kv_repeat: usize,
num_heads: usize,
neg_inf: Tensor,
tanh_gate_alpha: Option<Tensor>,
span: tracing::Span,
}

impl StreamingMultiheadCrossAttention {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
pub fn new(ca: CrossAttention, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let embed_dim = cfg.d_model;
let num_kv = cfg.num_heads / cfg.kv_repeat;
let kv_dim = num_kv * (embed_dim / cfg.num_heads);
@@ -269,6 +276,10 @@ impl StreamingMultiheadCrossAttention {
let in_proj_v = Linear::new(in_proj_weight_v, in_proj_bias_v);
let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("out_proj"))?;
let neg_inf = Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?;
let tanh_gate_alpha = match ca {
CrossAttention::Gated => Some(vb.get((1, 1, 1), "tanh_gate.alpha")?.tanh()?),
CrossAttention::Normal => None,
};
Ok(Self {
in_proj_q,
in_proj_k,
@@ -277,6 +288,7 @@ impl StreamingMultiheadCrossAttention {
kv_repeat: cfg.kv_repeat,
num_heads: cfg.num_heads,
neg_inf,
tanh_gate_alpha,
span: tracing::span!(tracing::Level::TRACE, "mhca"),
})
}
@@ -320,39 +332,28 @@ impl StreamingMultiheadCrossAttention {
.transpose(1, 2)? // b,t,h,d
.reshape((b, t, hd))?
.apply(&self.out_proj)?;
let xs = match self.tanh_gate_alpha.as_ref() {
None => xs,
Some(alpha) => xs.broadcast_mul(alpha)?,
};
Ok(xs)
}
}

#[derive(Debug, Clone)]
pub enum Mlp {
NoGating {
span1: tracing::Span,
linear1: Linear,
span2: tracing::Span,
linear2: Linear,
span: tracing::Span,
},
Gating {
linear_in: Linear,
linear_out: Linear,
activation: candle_nn::Activation,
span: tracing::Span,
},
NoGating { linear1: Linear, linear2: Linear },
Gating { linear_in: Linear, linear_out: Linear, activation: candle_nn::Activation },
}

impl Mlp {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let d_model = cfg.d_model;
let span = tracing::span!(tracing::Level::TRACE, "mlp");

match cfg.gating {
None => {
let span1 = tracing::span!(tracing::Level::TRACE, "lin1");
let span2 = tracing::span!(tracing::Level::TRACE, "lin2");
let linear1 = linear(d_model, cfg.dim_feedforward, cfg.bias_ff, vb.pp("linear1"))?;
let linear2 = linear(cfg.dim_feedforward, d_model, cfg.bias_ff, vb.pp("linear2"))?;
Ok(Self::NoGating { linear1, linear2, span, span1, span2 })
Ok(Self::NoGating { linear1, linear2 })
}
Some(activation) => {
let vb = vb.pp("gating");
@@ -364,7 +365,7 @@ impl Mlp {
// TODO: Maybe use bias_ff here?
let linear_in = linear(d_model, 2 * hidden, false, vb.pp("linear_in"))?;
let linear_out = linear(hidden, d_model, false, vb.pp("linear_out"))?;
Ok(Self::Gating { linear_in, linear_out, activation, span })
Ok(Self::Gating { linear_in, linear_out, activation })
}
}
}
@@ -373,20 +374,8 @@ impl Mlp {
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::NoGating { linear1, linear2, span, span1, span2 } => {
let _enter = span.enter();
let xs = {
let _enter = span1.enter();
xs.apply(linear1)?
};
let xs = xs.gelu_erf()?;
{
let _enter = span2.enter();
xs.apply(linear2)
}
}
Self::Gating { linear_in, linear_out, activation, span } => {
let _enter = span.enter();
Self::NoGating { linear1, linear2 } => xs.apply(linear1)?.gelu_erf()?.apply(linear2),
Self::Gating { linear_in, linear_out, activation } => {
let xs = xs.apply(linear_in)?;
let (b, t, _) = xs.dims3()?;
let xs = xs.reshape((b, t, 2, ()))?;
@@ -416,17 +405,51 @@ impl Module for RmsNorm {
}
}

#[derive(Debug, Clone)]
pub struct LayerNorm {
inner: candle_nn::LayerNorm,
}

impl LayerNorm {
pub fn new(d_model: usize, eps: f32, vb: VarBuilder) -> Result<Self> {
let bias = vb.get(d_model, "bias")?;
let alpha = if vb.contains_tensor("alpha") {
vb.get((1, 1, d_model), "alpha")?.reshape(d_model)?
} else {
vb.get(d_model, "weight")?.reshape(d_model)?
};
let inner = candle_nn::LayerNorm::new(alpha, bias, eps as f64);
Ok(Self { inner })
}

pub fn new_no_bias(d_model: usize, eps: f32, vb: VarBuilder) -> Result<Self> {
let alpha = if vb.contains_tensor("alpha") {
vb.get((1, 1, d_model), "alpha")?.reshape(d_model)?
} else {
vb.get(d_model, "weight")?.reshape(d_model)?
};
let inner = candle_nn::LayerNorm::new_no_bias(alpha, eps as f64);
Ok(Self { inner })
}
}

impl Module for LayerNorm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self.inner.forward(xs)
}
}

#[derive(Debug, Clone)]
pub enum Norm {
LayerNorm(candle_nn::LayerNorm),
LayerNorm(LayerNorm),
RmsNorm(RmsNorm),
}

impl Norm {
pub fn new(d_model: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let norm = match cfg.norm {
crate::NormType::LayerNorm => {
let norm = candle_nn::layer_norm(d_model, 1e-5, vb)?;
let norm = LayerNorm::new(d_model, 1e-5, vb)?;
Self::LayerNorm(norm)
}
crate::NormType::RmsNorm => {
@@ -455,7 +478,7 @@ pub struct StreamingTransformerLayer {
norm2: Norm,
layer_scale_1: Option<LayerScale>,
layer_scale_2: Option<LayerScale>,
cross_attn: Option<(candle_nn::LayerNorm, StreamingMultiheadCrossAttention)>,
cross_attn: Option<(LayerNorm, StreamingMultiheadCrossAttention)>,
norm_first: bool,
span: tracing::Span,
}
@@ -469,8 +492,8 @@ impl StreamingTransformerLayer {
let mlp = Mlp::new(cfg, vb.clone())?;
let (norm1, norm2) = match cfg.norm {
crate::NormType::LayerNorm => {
let norm1 = candle_nn::layer_norm(d_model, 1e-5, vb.pp("norm1"))?;
let norm2 = candle_nn::layer_norm(d_model, 1e-5, vb.pp("norm2"))?;
let norm1 = LayerNorm::new(d_model, 1e-5, vb.pp("norm1"))?;
let norm2 = LayerNorm::new(d_model, 1e-5, vb.pp("norm2"))?;
(Norm::LayerNorm(norm1), Norm::LayerNorm(norm2))
}
crate::NormType::RmsNorm => {
@@ -494,12 +517,14 @@ impl StreamingTransformerLayer {
}
};
let self_attn = StreamingMultiheadAttention::new(rope, cfg, vb.pp("self_attn"))?;
let cross_attn = if cfg.cross_attention {
let norm_cross = candle_nn::layer_norm(cfg.d_model, 1e-5, vb.pp("norm_cross"))?;
let cross_attn = StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"))?;
Some((norm_cross, cross_attn))
} else {
None
let cross_attn = match cfg.cross_attention {
Some(ca) => {
let norm_cross = LayerNorm::new_no_bias(cfg.d_model, 1e-5, vb.pp("norm_cross"))?;
let cross_attn =
StreamingMultiheadCrossAttention::new(ca, cfg, vb.pp("cross_attention"))?;
Some((norm_cross, cross_attn))
}
None => None,
};
Ok(Self {
self_attn,
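
The Gated variant above scales the projected cross-attention output by tanh(alpha), where alpha is a learned (1, 1, 1) tensor loaded from tanh_gate.alpha (the tanh is applied once at load time in the diff). A standalone sketch of that gating step using the same candle ops; the function and variable names are illustrative, not part of the crate:

use candle::{DType, Device, Result, Tensor};

// Illustrative sketch: broadcast-multiply the (batch, time, dim) attention
// output by tanh(alpha), where alpha has shape (1, 1, 1).
fn apply_tanh_gate(attn_out: &Tensor, alpha: &Tensor) -> Result<Tensor> {
    attn_out.broadcast_mul(&alpha.tanh()?)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let attn_out = Tensor::zeros((2, 4, 8), DType::F32, &dev)?;
    let alpha = Tensor::new(&[[[0.5f32]]], &dev)?;
    let gated = apply_tanh_gate(&attn_out, &alpha)?;
    assert_eq!(gated.dims(), &[2, 4, 8]);
    Ok(())
}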
