Add device mapping support
EricLBuehler committed Sep 29, 2024
1 parent f64dbf0 commit 21f282c
Showing 2 changed files with 91 additions and 27 deletions.
4 changes: 2 additions & 2 deletions mistralrs-core/src/vision_models/mllama/mod.rs
@@ -102,8 +102,8 @@ impl MLlamaModel
                 &cfg.text_config,
                 vb.pp("language_model"),
                 is_gptx,
-                &normal_loading_metadata,
-                &attention_mechanism,
+                normal_loading_metadata,
+                attention_mechanism,
             )?,
             multi_modal_projector: linear(
                 cfg.vision_config.vision_output_dim,
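The call-site change above is the whole of this file's diff: `MLlamaTextModel::new` now consumes `NormalLoadingMetadata` and `AttentionImplementation` by value rather than by reference, since the text model takes ownership of the device mapper carried in the metadata. A minimal sketch of the metadata fields this commit relies on, inferred from the usages in text.rs below (the real struct in mistralrs-core may carry more fields):

```rust
use candle_core::Device;

use crate::device_map::DeviceMapper; // trait added to the imports in text.rs below

// Inferred sketch of the NormalLoadingMetadata fields this commit uses;
// not the full struct definition.
pub struct NormalLoadingMetadata {
    /// Chooses a device for each transformer layer; taken by value here.
    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
    /// Whether weights are being loaded for in-situ quantization (ISQ).
    pub loading_isq: bool,
    /// Fallback device for layers the mapper does not place explicitly.
    pub real_device: Device,
}
```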
114 changes: 89 additions & 25 deletions mistralrs-core/src/vision_models/mllama/text.rs
@@ -1,12 +1,13 @@
 #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
 
-use std::sync::Arc;
+use std::{collections::HashMap, sync::Arc};
 
 use candle_core::{Device, IndexOp, Result, Tensor};
 use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module, VarBuilder};
 
 use crate::{
     attention::SdpaParams,
+    device_map::DeviceMapper,
     layers::{repeat_kv, CausalMasker, Llama3RotaryEmbedding, MatMul, RmsNorm, Sdpa},
     layers_masker::PastKvLenCache,
     paged_attention::{AttentionImplementation, ModelConfigMetadata},
@@ -154,16 +155,26 @@ impl MLlamaSelfAttentionDecoderLayer
         cfg: &MLlamaTextConfig,
         vb: VarBuilder,
         rope: Arc<Llama3RotaryEmbedding>,
+        mapper: &dyn DeviceMapper,
+        layer_idx: usize,
+        loading_isq: bool,
     ) -> Result<Self> {
-        let mlp = MLlamaTextMlp::new(cfg, vb.pp("mlp"))?;
-        let input_layernorm =
-            RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
+        let mlp = MLlamaTextMlp::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
+        let input_layernorm = RmsNorm::new(
+            cfg.hidden_size,
+            cfg.rms_norm_eps,
+            mapper.set_nm_device(vb.pp("input_layernorm"), false),
+        )?;
         let post_attention_layernorm = RmsNorm::new(
             cfg.hidden_size,
             cfg.rms_norm_eps,
-            vb.pp("post_attention_layernorm"),
+            mapper.set_nm_device(vb.pp("post_attention_layernorm"), false),
         )?;
-        let attn = MLlamaTextSelfAttention::new(cfg, vb.pp("self_attn"), rope)?;
+        let attn = MLlamaTextSelfAttention::new(
+            cfg,
+            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
+            rope,
+        )?;
 
         Ok(Self {
             attn,
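Every weight in the rewritten constructor is now routed through the mapper: per-layer weights (`mlp`, `self_attn`) go through `set_device`, while the norms go through `set_nm_device`. A sketch of the `DeviceMapper` surface these call sites imply, inferred from this diff rather than copied from `crate::device_map`:

```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::VarBuilder;

// Inferred sketch; method names and shapes come from the call sites in
// this commit, and the real trait may differ in detail.
pub trait DeviceMapper {
    /// VarBuilder pinned to the device chosen for `layer` (or left on the
    /// loading device while quantizing, when `loading_isq` is true).
    fn set_device<'a>(&self, layer: usize, vb: VarBuilder<'a>, loading_isq: bool)
        -> VarBuilder<'a>;
    /// VarBuilder for weights handled outside the per-layer mapping; this
    /// diff uses it for norms, embeddings, and the lm_head.
    fn set_nm_device<'a>(&self, vb: VarBuilder<'a>, loading_isq: bool) -> VarBuilder<'a>;
    /// Device hosting `layer`, if the mapper placed it explicitly.
    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device>;
    /// Move activations onto `layer`'s device before running it.
    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor>;
}
```

Note that the norm constructors always pass `loading_isq = false`, presumably because norms are never quantized and so never need to stay on the ISQ loading device.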
@@ -316,16 +327,28 @@ struct MLlamaCrossAttentionDecoderLayer
 }
 
 impl MLlamaCrossAttentionDecoderLayer {
-    fn new(cfg: &MLlamaTextConfig, vb: VarBuilder) -> Result<Self> {
-        let mlp = MLlamaTextMlp::new(cfg, vb.pp("mlp"))?;
-        let input_layernorm =
-            RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
+    fn new(
+        cfg: &MLlamaTextConfig,
+        vb: VarBuilder,
+        mapper: &dyn DeviceMapper,
+        layer_idx: usize,
+        loading_isq: bool,
+    ) -> Result<Self> {
+        let mlp = MLlamaTextMlp::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
+        let input_layernorm = RmsNorm::new(
+            cfg.hidden_size,
+            cfg.rms_norm_eps,
+            mapper.set_nm_device(vb.pp("input_layernorm"), false),
+        )?;
         let post_attention_layernorm = RmsNorm::new(
             cfg.hidden_size,
             cfg.rms_norm_eps,
-            vb.pp("post_attention_layernorm"),
+            mapper.set_nm_device(vb.pp("post_attention_layernorm"), false),
         )?;
-        let attn = MLlamaTextCrossAttention::new(cfg, vb.pp("cross_attn"))?;
+        let attn = MLlamaTextCrossAttention::new(
+            cfg,
+            mapper.set_device(layer_idx, vb.pp("cross_attn"), loading_isq),
+        )?;
 
         Ok(Self {
             attn,
@@ -384,51 +407,89 @@ pub(super) struct MLlamaTextModel
     pub(crate) self_attn_cache: Cache,
     pub(crate) device: Device,
     pub(crate) max_position_embeddings: usize,
+    mapper: Box<dyn DeviceMapper + Send + Sync>,
 }
 
 impl MLlamaTextModel {
     pub(super) fn new(
         cfg: &MLlamaTextConfig,
         vb: VarBuilder,
         is_gptx: bool,
-        _normal_loading_metadata: &NormalLoadingMetadata,
-        _attention_mechanism: &AttentionImplementation,
+        normal_loading_metadata: NormalLoadingMetadata,
+        attention_mechanism: AttentionImplementation,
     ) -> Result<Self> {
+        if !matches!(attention_mechanism, AttentionImplementation::Eager) {
+            candle_core::bail!("Expected eager attention implementation");
+        }
+        let mapper = normal_loading_metadata.mapper;
+
         let embed_tokens = embedding(
             cfg.vocab_size + 8,
             cfg.hidden_size,
-            vb.pp("model.embed_tokens"),
+            mapper.set_nm_device(vb.pp("model.embed_tokens"), false),
         )?;
 
         let lm_head = if cfg.tie_word_embeddings {
             Linear::new(embed_tokens.embeddings().clone(), None)
         } else {
-            linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
+            linear_no_bias(
+                cfg.hidden_size,
+                cfg.vocab_size,
+                mapper.set_nm_device(vb.pp("lm_head"), false),
+            )?
         };
 
         let vb = vb.pp("model");
 
-        let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("norm"))?;
+        let norm = RmsNorm::new(
+            cfg.hidden_size,
+            cfg.rms_norm_eps,
+            mapper.set_nm_device(vb.pp("norm"), false),
+        )?;
 
-        let rope = Arc::new(Llama3RotaryEmbedding::new_mllama3(
-            vb.dtype(),
-            cfg,
-            vb.device(),
-            is_gptx,
-        )?);
+        let mut ropes = HashMap::new();
+        for layer_idx in 0..cfg.num_hidden_layers {
+            let device = mapper
+                .device_for(layer_idx, false)
+                .unwrap_or(&normal_loading_metadata.real_device);
+            ropes.insert(
+                device.location(),
+                Arc::new(Llama3RotaryEmbedding::new_mllama3(
+                    vb.dtype(),
+                    cfg,
+                    device,
+                    is_gptx,
+                )?),
+            );
+        }
 
         let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
         for i in 0..cfg.num_hidden_layers {
             if cfg.cross_attention_layers.contains(&i) {
                 layers.push(MLlamaDecoderLayer::CrossAttn(
-                    MLlamaCrossAttentionDecoderLayer::new(cfg, vb.pp(format!("layers.{i}")))?,
+                    MLlamaCrossAttentionDecoderLayer::new(
+                        cfg,
+                        vb.pp(format!("layers.{i}")),
+                        &*mapper,
+                        i,
+                        normal_loading_metadata.loading_isq,
+                    )?,
                 ))
             } else {
+                let device = mapper
+                    .device_for(i, false)
+                    .unwrap_or(&normal_loading_metadata.real_device);
                 layers.push(MLlamaDecoderLayer::SelfAttn(
                     MLlamaSelfAttentionDecoderLayer::new(
                         cfg,
                         vb.pp(format!("layers.{i}")),
-                        rope.clone(),
+                        ropes
+                            .get(&device.location())
+                            .expect("No RoPE for device location!")
+                            .clone(),
+                        &*mapper,
+                        i,
+                        normal_loading_metadata.loading_isq,
                     )?,
                 ))
             }
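Rotary-embedding tables are ordinary tensors, so once layers may live on different devices a single shared `Arc<Llama3RotaryEmbedding>` no longer suffices: the constructor builds one per distinct device, keyed by `candle_core::DeviceLocation`, and each self-attention layer looks up the copy for its own device. The pattern in isolation, as a hedged sketch (`Rope::new_on` is a hypothetical stand-in for `Llama3RotaryEmbedding::new_mllama3`):

```rust
use std::{collections::HashMap, sync::Arc};

use candle_core::{Device, DeviceLocation, Result};

// Hypothetical stand-in for Llama3RotaryEmbedding in this sketch.
struct Rope;
impl Rope {
    fn new_on(_device: &Device) -> Result<Self> {
        Ok(Rope) // real code would precompute cos/sin tables on `device`
    }
}

// Build one RoPE per distinct device location, then share it (via Arc)
// among all layers mapped to that device.
fn build_ropes(layer_devices: &[Device]) -> Result<HashMap<DeviceLocation, Arc<Rope>>> {
    let mut ropes = HashMap::new();
    for device in layer_devices {
        if !ropes.contains_key(&device.location()) {
            ropes.insert(device.location(), Arc::new(Rope::new_on(device)?));
        }
    }
    Ok(ropes)
}
```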
@@ -450,6 +511,7 @@ impl MLlamaTextModel
             self_attn_cache: Cache::new(cfg.num_hidden_layers, false),
             device: vb.device().clone(),
             max_position_embeddings: cfg.max_position_embeddings,
+            mapper,
         })
     }
 
@@ -475,6 +537,7 @@ impl MLlamaTextModel
         )?;
 
         for (i, layer) in self.layers.iter().enumerate() {
+            hidden_states = self.mapper.map(hidden_states, i)?;
             match layer {
                 MLlamaDecoderLayer::SelfAttn(attn) => {
                     hidden_states = attn.forward(
@@ -503,6 +566,7 @@ impl MLlamaTextModel
             }
         }
 
+        hidden_states = hidden_states.to_device(&self.device)?;
         hidden_states = self.norm.forward(&hidden_states)?;
 
         hidden_states = self
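The runtime half of the feature is the pair of added lines in the forward pass: `self.mapper.map(hidden_states, i)` moves activations onto layer `i`'s device before each block runs, and the final `to_device(&self.device)` brings them back to the primary device for the closing norm and `lm_head`. A minimal sketch of that control flow, assuming only candle's `Module` trait and the `DeviceMapper` trait from this diff:

```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::Module;

use crate::device_map::DeviceMapper; // crate-internal trait, per this diff

// Sketch of the device-hopping pattern, not the model's actual forward.
fn run_mapped_layers(
    mut xs: Tensor,
    layers: &[Box<dyn Module>],
    mapper: &dyn DeviceMapper,
    primary: &Device,
) -> Result<Tensor> {
    for (i, layer) in layers.iter().enumerate() {
        // Copy activations to whichever device hosts layer `i`
        // (a no-op when they are already there).
        xs = mapper.map(xs, i)?;
        xs = layer.forward(&xs)?;
    }
    // Return to the primary device for the final norm and lm_head.
    xs.to_device(primary)
}
```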
