From 743da2b43903168224541c81e9c7faeac396daf5 Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Fri, 9 Aug 2024 19:14:57 -0400
Subject: [PATCH 1/5] Add the config

---
 mistralrs-core/src/models/mamba2.rs | 81 +++++++++++++++++++++++++++++
 mistralrs-core/src/models/mod.rs    |  1 +
 2 files changed, 82 insertions(+)
 create mode 100644 mistralrs-core/src/models/mamba2.rs

diff --git a/mistralrs-core/src/models/mamba2.rs b/mistralrs-core/src/models/mamba2.rs
new file mode 100644
index 000000000..ac5396dca
--- /dev/null
+++ b/mistralrs-core/src/models/mamba2.rs
@@ -0,0 +1,81 @@
+use candle_nn::Activation;
+use serde::Deserialize;
+
+use crate::serde_default_fn;
+
+serde_default_fn!(usize, num_heads_default, 128);
+serde_default_fn!(usize, head_dim_default, 64);
+serde_default_fn!(usize, vocab_size_default, 32768);
+serde_default_fn!(usize, hidden_size_default, 4096);
+serde_default_fn!(usize, state_size_default, 128);
+serde_default_fn!(usize, num_hidden_layers_default, 64);
+serde_default_fn!(f64, layer_norm_epsilon_default, 1e-5);
+serde_default_fn!(usize, expand_default, 2);
+serde_default_fn!(usize, conv_kernel_default, 4);
+serde_default_fn!(usize, n_groups_default, 2);
+serde_default_fn!(bool, use_bias_default, false);
+serde_default_fn!(bool, use_conv_bias_default, true);
+serde_default_fn!(Activation, hidden_act_default, Activation::Silu);
+serde_default_fn!(bool, residual_in_fp32_default, true);
+serde_default_fn!(f64, time_step_min_default, 0.001);
+serde_default_fn!(f64, time_step_max_default, 0.1);
+serde_default_fn!(f64, time_step_floor_default, 0.0001);
+serde_default_fn!((f64, f64), time_step_limit_default, (0.0, f64::INFINITY));
+serde_default_fn!(bool, rescale_prenorm_residual_default, false);
+serde_default_fn!(bool, norm_before_gate_default, true);
+serde_default_fn!(bool, rms_norm_default, true);
+serde_default_fn!(usize, chunk_size_default, 256);
+
+#[derive(Debug, Clone, Deserialize, Default)]
+pub enum TimeStepRank {
+    #[default]
+    Auto,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct Mamba2Config {
+    #[serde(default = "num_heads_default")]
+    num_heads: usize,
+    #[serde(default = "head_dim_default")]
+    head_dim: usize,
+    #[serde(default = "vocab_size_default")]
+    vocab_size: usize,
+    #[serde(default = "hidden_size_default")]
+    hidden_size: usize,
+    #[serde(default = "state_size_default")]
+    state_size: usize,
+    #[serde(default = "num_hidden_layers_default")]
+    num_hidden_layers: usize,
+    #[serde(default = "layer_norm_epsilon_default")]
+    layer_norm_epsilon: f64,
+    #[serde(default = "expand_default")]
+    expand: usize,
+    #[serde(default = "conv_kernel_default")]
+    conv_kernel: usize,
+    #[serde(default = "n_groups_default")]
+    n_groups: usize,
+    #[serde(default = "use_bias_default")]
+    use_bias: bool,
+    #[serde(default = "use_conv_bias_default")]
+    use_conv_bias: bool,
+    #[serde(default = "hidden_act_default")]
+    hidden_act: Activation,
+    #[serde(default = "residual_in_fp32_default")]
+    residual_in_fp32: bool,
+    #[serde(default = "Default::default")]
+    time_step_rank: TimeStepRank,
+    #[serde(default = "time_step_min_default")]
+    time_step_min: f64,
+    #[serde(default = "time_step_floor_default")]
+    time_step_floor: f64,
+    #[serde(default = "time_step_limit_default")]
+    time_step_limit: (f64, f64),
+    #[serde(default = "rescale_prenorm_residual_default")]
+    rescale_prenorm_residual: bool,
+    #[serde(default = "norm_before_gate_default")]
+    norm_before_gate: bool,
+    #[serde(default = "rms_norm_default")]
+    rms_norm: bool,
+    #[serde(default = "chunk_size_default")]
+    chunk_size: usize,
+}
diff --git a/mistralrs-core/src/models/mod.rs b/mistralrs-core/src/models/mod.rs
index bf1b55976..a6f70271a 100644
--- a/mistralrs-core/src/models/mod.rs
+++ b/mistralrs-core/src/models/mod.rs
@@ -1,6 +1,7 @@
 pub(crate) mod gemma;
 pub(crate) mod gemma2;
 pub(crate) mod llama;
+pub(crate) mod mamba2;
 pub(crate) mod mistral;
 pub(crate) mod mixtral;
 pub(crate) mod phi2;
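Every field above carries a serde default, so a checkpoint's config.json only needs to spell out what it overrides. A minimal sketch of how that is exercised, assuming serde_json ends up doing the parsing and using an illustrative helper name:

    fn parse_mamba2_config(raw: &str) -> serde_json::Result<Mamba2Config> {
        // Any field missing from `raw` falls back to its *_default function above,
        // e.g. num_heads = 128, head_dim = 64, state_size = 128, chunk_size = 256.
        serde_json::from_str(raw)
    }

    // parse_mamba2_config(r#"{"vocab_size": 32000, "num_hidden_layers": 48}"#)
    // keeps the two overridden values and defaults everything else.
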
From b6be723b9f0da7448cf440e2480edcc8c22620d7 Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Sat, 10 Aug 2024 14:24:21 -0400
Subject: [PATCH 2/5] Prep the skeleton

---
 mistralrs-core/src/layers.rs        | 19 +++++++++++++++
 mistralrs-core/src/models/mamba2.rs | 37 +++++++++++++++++++++++++++--
 2 files changed, 54 insertions(+), 2 deletions(-)

diff --git a/mistralrs-core/src/layers.rs b/mistralrs-core/src/layers.rs
index 0d172890c..4930d71ad 100644
--- a/mistralrs-core/src/layers.rs
+++ b/mistralrs-core/src/layers.rs
@@ -69,6 +69,25 @@ impl QRmsNorm {
     }
 }
 
+pub struct GatedRmsNorm(RmsNorm);
+
+impl GatedRmsNorm {
+    pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+        let inner = candle_nn::rms_norm_non_quant(size, eps, vb)?;
+        let w = inner.inner().weight().clone();
+        Ok(Self(RmsNorm { eps, weight: w }))
+    }
+
+    fn forward(&self, x: &Tensor, gate: Option<&Tensor>) -> Result<Tensor> {
+        let x = if let Some(gate) = gate {
+            x.broadcast_mul(&gate.to_dtype(x.dtype())?)?.contiguous()?
+        } else {
+            x.contiguous()?
+        };
+        candle_nn::ops::rms_norm(&x, &self.0.weight, self.0.eps as f32)
+    }
+}
+
 /// RoPE supporting LongRope
 #[derive(Debug, Clone)]
 pub struct PhiRotaryEmbedding {
diff --git a/mistralrs-core/src/models/mamba2.rs b/mistralrs-core/src/models/mamba2.rs
index ac5396dca..ee547f035 100644
--- a/mistralrs-core/src/models/mamba2.rs
+++ b/mistralrs-core/src/models/mamba2.rs
@@ -1,7 +1,14 @@
-use candle_nn::Activation;
+use std::sync::Arc;
+
+use candle_core::Tensor;
+use candle_nn::{Activation, Conv1d, Embedding, Linear};
+use mistralrs_quant::QuantMethod;
 use serde::Deserialize;
 
-use crate::serde_default_fn;
+use crate::{
+    layers::{GatedRmsNorm, RmsNorm},
+    serde_default_fn,
+};
 
 serde_default_fn!(usize, num_heads_default, 128);
 serde_default_fn!(usize, head_dim_default, 64);
@@ -66,6 +73,8 @@ pub struct Mamba2Config {
     time_step_rank: TimeStepRank,
     #[serde(default = "time_step_min_default")]
     time_step_min: f64,
+    #[serde(default = "time_step_max_default")]
+    time_step_max: f64,
     #[serde(default = "time_step_floor_default")]
     time_step_floor: f64,
     #[serde(default = "time_step_limit_default")]
     time_step_limit: (f64, f64),
@@ -79,3 +88,27 @@ pub struct Mamba2Config {
     #[serde(default = "chunk_size_default")]
     chunk_size: usize,
 }
+
+// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mamba2/modeling_mamba2.py#L406
+struct Mixer {
+    conv1d: Conv1d,
+    in_proj: Arc<dyn QuantMethod>,
+    dt_bias: Tensor,
+    a_log: Tensor,
+    d: Tensor,
+    norm: GatedRmsNorm,
+    out_proj: Arc<dyn QuantMethod>,
+}
+
+struct Layer {
+    norm: RmsNorm,
+    mixer: Mixer,
+    res_in_f32: bool,
+}
+
+pub struct Model {
+    lm_head: Linear,
+    embeddings: Embedding,
+    norm_f: RmsNorm,
+    layers: Vec<Layer>,
+}
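As written, GatedRmsNorm folds the gate in before normalizing: for input x, gate g, learned weight w and epsilon eps, the forward pass above computes, over the last dimension,

    y = w * (x * g) / sqrt(mean((x * g)^2) + eps)

and reduces to a plain RMSNorm when no gate is passed. In the modeling_mamba2.py file referenced above struct Mixer, this kind of gated norm sits between the SSM output and out_proj, which is why the skeleton pairs a GatedRmsNorm with an out_proj in the Mixer.
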
From 406288f1f8a8e63a6fc0df060f32d0ec496bc895 Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Tue, 20 Aug 2024 20:46:28 -0400
Subject: [PATCH 3/5] Add some layers

---
 mistralrs-core/src/models/mamba2.rs | 56 ++++++++++++++++++++++++++---
 1 file changed, 52 insertions(+), 4 deletions(-)

diff --git a/mistralrs-core/src/models/mamba2.rs b/mistralrs-core/src/models/mamba2.rs
index ee547f035..54c896d60 100644
--- a/mistralrs-core/src/models/mamba2.rs
+++ b/mistralrs-core/src/models/mamba2.rs
@@ -1,12 +1,16 @@
 use std::sync::Arc;
 
-use candle_core::Tensor;
-use candle_nn::{Activation, Conv1d, Embedding, Linear};
-use mistralrs_quant::QuantMethod;
+use candle_core::{Result, Tensor};
+use candle_nn::{Activation, Conv1d, Embedding, Linear, VarBuilder};
+use mistralrs_quant::{
+    linear, linear_no_bias, QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear,
+};
 use serde::Deserialize;
 
 use crate::{
     layers::{GatedRmsNorm, RmsNorm},
+    paged_attention::AttentionImplementation,
+    pipeline::NormalLoadingMetadata,
     serde_default_fn,
 };
 
 serde_default_fn!(usize, num_heads_default, 128);
 serde_default_fn!(usize, head_dim_default, 64);
@@ -87,6 +91,7 @@ pub struct Mamba2Config {
     rms_norm: bool,
     #[serde(default = "chunk_size_default")]
     chunk_size: usize,
+    quantization_config: Option<QuantizedConfig>,
 }
 
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/mamba2/modeling_mamba2.py#L406
 struct Mixer {
@@ -107,8 +112,51 @@ struct Layer {
 }
 
 pub struct Model {
-    lm_head: Linear,
+    lm_head: Arc<dyn QuantMethod>,
     embeddings: Embedding,
     norm_f: RmsNorm,
     layers: Vec<Layer>,
 }
+
+impl Model {
+    pub fn new(
+        cfg: &Mamba2Config,
+        vb: VarBuilder,
+        _is_gptx: bool,
+        normal_loading_metadata: NormalLoadingMetadata,
+        attention_mechanism: AttentionImplementation,
+    ) -> Result<Self> {
+        if let Some(ref quant_cfg) = &cfg.quantization_config {
+            tracing::info!(
+                "Using {} quantization in {} bits.",
+                quant_cfg.quant_method.to_string(),
+                quant_cfg.bits
+            );
+        }
+        let mapper = normal_loading_metadata.mapper;
+
+        let vb = vb.pp("backbone");
+
+        let embeddings = candle_nn::embedding(
+            cfg.vocab_size,
+            cfg.hidden_size,
+            mapper.set_nm_device(vb.pp("embeddings"), false),
+        )?;
+        // Tied lm_head...
+        let lm_head = Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
+            Linear::new(
+                mapper.cast_nm_device(
+                    &embeddings.embeddings(),
+                    normal_loading_metadata.loading_isq,
+                )?,
+                None,
+            ),
+        ))?);
+        let norm_f = RmsNorm::new(
+            cfg.hidden_size,
+            cfg.layer_norm_epsilon,
+            mapper.set_nm_device(vb.pp("norm_f"), false),
+        )?;
+        todo!()
+    }
+}
From f12c9a966c9a2ef442bf81d5275baa49c7b0d28b Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Wed, 21 Aug 2024 14:13:23 -0400
Subject: [PATCH 4/5] Add the loading

---
 mistralrs-core/src/models/mamba2.rs | 127 ++++++++++++++++++++++------
 1 file changed, 100 insertions(+), 27 deletions(-)

diff --git a/mistralrs-core/src/models/mamba2.rs b/mistralrs-core/src/models/mamba2.rs
index 54c896d60..950c15e3a 100644
--- a/mistralrs-core/src/models/mamba2.rs
+++ b/mistralrs-core/src/models/mamba2.rs
@@ -1,10 +1,12 @@
+#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
+
 use std::sync::Arc;
 
 use candle_core::{Result, Tensor};
-use candle_nn::{Activation, Conv1d, Embedding, Linear, VarBuilder};
-use mistralrs_quant::{
-    linear, linear_no_bias, QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear,
+use candle_nn::{
+    conv1d, conv1d_no_bias, Activation, Conv1d, Conv1dConfig, Embedding, Linear, VarBuilder,
 };
+use mistralrs_quant::{linear_b, QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear};
 use serde::Deserialize;
 
 use crate::{
     layers::{GatedRmsNorm, RmsNorm},
@@ -21,7 +23,7 @@ serde_default_fn!(usize, hidden_size_default, 4096);
 serde_default_fn!(usize, state_size_default, 128);
 serde_default_fn!(usize, num_hidden_layers_default, 64);
 serde_default_fn!(f64, layer_norm_epsilon_default, 1e-5);
-serde_default_fn!(usize, expand_default, 2);
+serde_default_fn!(f64, expand_default, 2.0);
 serde_default_fn!(usize, conv_kernel_default, 4);
 serde_default_fn!(usize, n_groups_default, 2);
 serde_default_fn!(bool, use_bias_default, false);
@@ -60,7 +62,7 @@ pub struct Mamba2Config {
     #[serde(default = "layer_norm_epsilon_default")]
     layer_norm_epsilon: f64,
     #[serde(default = "expand_default")]
-    expand: usize,
+    expand: f64,
     #[serde(default = "conv_kernel_default")]
     conv_kernel: usize,
     #[serde(default = "n_groups_default")]
     n_groups: usize,
@@ -105,12 +107,86 @@ struct Mixer {
     out_proj: Arc<dyn QuantMethod>,
 }
 
+impl Mixer {
+    fn new(cfg: &Mamba2Config, vb: VarBuilder) -> Result<Self> {
+        let intermediate_size = (cfg.expand * cfg.hidden_size as f64) as usize;
+        let conv_dim = intermediate_size + 2 * cfg.n_groups * cfg.state_size;
+        let projection_size = intermediate_size + conv_dim + cfg.num_heads;
+
+        let conv1d_fn = if cfg.use_conv_bias {
+            conv1d
+        } else {
+            conv1d_no_bias
+        };
+
+        let conv1d = conv1d_fn(
+            conv_dim,
+            conv_dim,
+            cfg.conv_kernel,
+            Conv1dConfig {
+                padding: cfg.conv_kernel - 1,
+                groups: conv_dim,
+                stride: 1,
+                dilation: 1,
+            },
+            vb.pp("conv1d"),
+        )?;
+
+        let in_proj = linear_b(
+            cfg.hidden_size,
+            projection_size,
+            cfg.use_bias,
+            &cfg.quantization_config,
+            vb.pp("in_proj"),
+        )?;
+
+        let out_proj = linear_b(
+            intermediate_size,
+            cfg.hidden_size,
+            cfg.use_bias,
+            &cfg.quantization_config,
+            vb.pp("in_proj"),
+        )?;
+
+        // Time step proj, discretization
+        let dt_bias = vb.get((cfg.num_heads,), "dt_bias")?;
+
+        // S4D real init, not discretized
+        let a_log = vb.get((1, cfg.num_heads + 1), "A_log")?;
+        let d = vb.get((cfg.num_heads,), "D")?;
+
+        let norm = GatedRmsNorm::new(intermediate_size, cfg.layer_norm_epsilon, vb.pp("norm"))?;
+
+        Ok(Self {
+            conv1d,
+            in_proj,
+            out_proj,
+            dt_bias,
+            a_log,
+            d,
+            norm,
+        })
+    }
+}
+
 struct Layer {
     norm: RmsNorm,
     mixer: Mixer,
     res_in_f32: bool,
 }
 
+impl Layer {
+    fn new(cfg: &Mamba2Config, vb: VarBuilder) -> Result<Self> {
+        let norm = RmsNorm::new(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm"))?;
+        let mixer = Mixer::new(cfg, vb.pp("mixer"))?;
+        Ok(Self {
+            norm,
+            mixer,
+            res_in_f32: cfg.residual_in_fp32,
+        })
+    }
+}
+
 pub struct Model {
     lm_head: Arc<dyn QuantMethod>,
     embeddings: Embedding,
     norm_f: RmsNorm,
@@ -123,8 +199,8 @@ impl Model {
         cfg: &Mamba2Config,
         vb: VarBuilder,
         _is_gptx: bool,
-        normal_loading_metadata: NormalLoadingMetadata,
-        attention_mechanism: AttentionImplementation,
+        _normal_loading_metadata: NormalLoadingMetadata,
+        _attention_mechanism: AttentionImplementation,
     ) -> Result<Self> {
         if let Some(ref quant_cfg) = &cfg.quantization_config {
             tracing::info!(
@@ -133,30 +209,27 @@ impl Model {
                 quant_cfg.bits
             );
         }
-        let mapper = normal_loading_metadata.mapper;
 
         let vb = vb.pp("backbone");
 
-        let embeddings = candle_nn::embedding(
-            cfg.vocab_size,
-            cfg.hidden_size,
-            mapper.set_nm_device(vb.pp("embeddings"), false),
-        )?;
-        // Tied lm_head...
+        let embeddings =
+            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embeddings"))?;
+        // Tied to lm_head...
         let lm_head = Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
-            Linear::new(
-                mapper.cast_nm_device(
-                    &embeddings.embeddings(),
-                    normal_loading_metadata.loading_isq,
-                )?,
-                None,
-            ),
+            Linear::new(embeddings.embeddings().clone(), None),
         ))?);
-        let norm_f = RmsNorm::new(
-            cfg.hidden_size,
-            cfg.layer_norm_epsilon,
-            mapper.set_nm_device(vb.pp("norm_f"), false),
-        )?;
-        todo!()
+        let norm_f = RmsNorm::new(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm_f"))?;
+
+        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+        for idx in 0..cfg.num_hidden_layers {
+            layers.push(Layer::new(cfg, vb.pp(idx))?);
+        }
+
+        Ok(Self {
+            lm_head,
+            embeddings,
+            norm_f,
+            layers,
+        })
     }
 }
From c0b93f9a448fa3e88b47fafb0c847d5f79defb0b Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Wed, 21 Aug 2024 16:48:24 -0400
Subject: [PATCH 5/5] Oops

---
 mistralrs-core/src/models/mamba2.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mistralrs-core/src/models/mamba2.rs b/mistralrs-core/src/models/mamba2.rs
index 950c15e3a..ef4939d59 100644
--- a/mistralrs-core/src/models/mamba2.rs
+++ b/mistralrs-core/src/models/mamba2.rs
@@ -145,7 +145,7 @@ impl Mixer {
             cfg.hidden_size,
             cfg.use_bias,
             &cfg.quantization_config,
-            vb.pp("in_proj"),
+            vb.pp("out_proj"),
         )?;
 
         // Time step proj, discretization
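For orientation, vb.pp(...) only builds the dotted prefix under which tensors are fetched, which is why the last one-line fix matters. With the defaults above, Mixer::new works out intermediate_size = 2.0 * 4096 = 8192, conv_dim = 8192 + 2 * 2 * 128 = 8704 (the extra 2 * n_groups * state_size channels carry the B and C SSM inputs alongside the hidden stream in the referenced modeling file), and projection_size = 8192 + 8704 + 128 = 17024, so in_proj emits the gate, the conv/SSM stream and one dt value per head in a single matmul. A rough, non-authoritative map of what this loader ends up requesting from the checkpoint (exact suffixes depend on the linear/conv/norm helpers):

    backbone.embeddings.weight             (vocab_size, hidden_size); also reused as the tied lm_head
    backbone.norm_f.weight                 (hidden_size,)
    backbone.{idx}.norm.weight             (hidden_size,)
    backbone.{idx}.mixer.conv1d.*          depthwise: groups = conv_dim, padding = conv_kernel - 1 keeps it causal
    backbone.{idx}.mixer.in_proj.weight    (projection_size, hidden_size) = (17024, 4096) with the defaults
    backbone.{idx}.mixer.out_proj.weight   (hidden_size, intermediate_size) = (4096, 8192) with the defaults
    backbone.{idx}.mixer.dt_bias, A_log, D, norm.weight

Loading out_proj through the in_proj prefix would therefore request a tensor of the wrong shape, which is what the final "Oops" commit corrects.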