Add a quantized blip model. (#1155)
* Add a quantized blip model.

* Integrate the quantized blip model into the existing example.
LaurentMazare authored Oct 22, 2023
1 parent 8a82d62 commit a11af79
Showing 5 changed files with 795 additions and 17 deletions.
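With this change the example runs either the float safetensors weights or the new q4k gguf weights. A plausible invocation, assuming the usual candle example layout (the image path here is illustrative):

cargo run --example blip --release -- --image bike.jpg --quantized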
70 changes: 53 additions & 17 deletions candle-examples/examples/blip/main.rs
@@ -11,9 +11,25 @@ use candle::{DType, Device, Result, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::blip;
use candle_transformers::models::quantized_blip;

use tokenizers::Tokenizer;

enum Model {
M(blip::BlipForConditionalGeneration),
Q(quantized_blip::BlipForConditionalGeneration),
}

impl Model {
fn text_decoder_forward(&mut self, xs: &Tensor, img_xs: &Tensor) -> Result<Tensor> {
match self {
Self::M(m) => m.text_decoder().forward(xs, img_xs),
Self::Q(m) => m.text_decoder().forward(xs, img_xs),
}
}
}

// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
#[arg(long)]
@@ -28,6 +44,10 @@ struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,

/// Use the quantized version of the model.
#[arg(long)]
quantized: bool,
}

const SEP_TOKEN_ID: u32 = 102;
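// Note: 102 is [SEP] in the BERT-style vocabulary used by the BLIP tokenizer; generation stops once it is sampled.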
@@ -54,20 +74,20 @@ pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();

let device = candle_examples::device(args.cpu)?;

let image = load_image(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");

let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.repo(hf_hub::Repo::with_revision(
"Salesforce/blip-image-captioning-large".to_string(),
hf_hub::RepoType::Model,
"refs/pr/18".to_string(),
));
api.get("model.safetensors")?
if args.quantized {
let api = api.model("lmz/candle-blip".to_string());
api.get("blip-image-captioning-large-q4k.gguf")?
} else {
let api = api.repo(hf_hub::Repo::with_revision(
"Salesforce/blip-image-captioning-large".to_string(),
hf_hub::RepoType::Model,
"refs/pr/18".to_string(),
));
api.get("model.safetensors")?
}
}
Some(model) => model.into(),
};
@@ -84,19 +104,35 @@ pub fn main() -> anyhow::Result<()> {
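// With no temperature and no top-p configured, the logits processor below decodes greedily (argmax).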
let mut logits_processor =
candle_transformers::generation::LogitsProcessor::new(1337, None, None);

let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let config = blip::Config::image_captioning_large();
let mut model = blip::BlipForConditionalGeneration::new(&config, vb)?;
println!("model built");
// TODO: Maybe add support for the conditional prompt.
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;

let (image_embeds, device, mut model) = if args.quantized {
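// The quantized path pins the device to CPU: the gguf-backed kernels used here are CPU-only at this point, so --cpu is effectively implied.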
let device = Device::Cpu;
let image = load_image(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");

let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
(image_embeds, device, Model::Q(model))
} else {
let device = candle_examples::device(args.cpu)?;
let image = load_image(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");

let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = blip::BlipForConditionalGeneration::new(&config, vb)?;
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
(image_embeds, device, Model::M(model))
};

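// 30522 is the decoder's BOS token id, one past the 30522-entry BERT vocab (per the BLIP text config).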
let mut token_ids = vec![30522u32];
for index in 0..1000 {
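// The first step feeds the whole prefix; later steps feed only the newest token, relying on the text decoder's key/value cache.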
let context_size = if index > 0 { 1 } else { token_ids.len() };
let start_pos = token_ids.len().saturating_sub(context_size);
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
let logits = model.text_decoder().forward(&input_ids, &image_embeds)?;
let logits = model.text_decoder_forward(&input_ids, &image_embeds)?;
let logits = logits.squeeze(0)?;
let logits = logits.get(logits.dim(0)? - 1)?;
let token = logits_processor.sample(&logits)?;
2 changes: 2 additions & 0 deletions candle-transformers/src/models/mod.rs
@@ -10,6 +10,8 @@ pub mod llama;
pub mod mistral;
pub mod mixformer;
pub mod mpt;
pub mod quantized_blip;
pub mod quantized_blip_text;
pub mod quantized_llama;
pub mod quantized_mistral;
pub mod quantized_mixformer;
258 changes: 258 additions & 0 deletions candle-transformers/src/models/quantized_blip.rs
@@ -0,0 +1,258 @@
use super::quantized_blip_text as blip_text;
use crate::quantized_nn::{layer_norm, linear, Linear};
pub use crate::quantized_var_builder::VarBuilder;
use candle::{Module, Result, Tensor, D};
use candle_nn::{Conv2d, Conv2dConfig, LayerNorm};

pub type VisionConfig = super::blip::VisionConfig;
pub type Config = super::blip::Config;

#[derive(Debug, Clone)]
struct VisionEmbeddings {
class_embedding: Tensor,
patch_embedding: Conv2d,
position_embedding: Tensor,
}

impl VisionEmbeddings {
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
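// These tensors are dequantized at load time: embeddings and conv2d run in float, while the model's Linear layers keep their quantized weights.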
let class_embedding = vb
.get((1, 1, cfg.hidden_size), "class_embedding")?
.dequantize(vb.device())?;
let conv_cfg = Conv2dConfig {
stride: cfg.patch_size,
..Default::default()
};
let pe_vb = vb.pp("patch_embedding");
let pe_weight = pe_vb
.get(
(cfg.hidden_size, 3, cfg.patch_size, cfg.patch_size),
"weight",
)?
.dequantize(vb.device())?;
let pe_bias = pe_vb
.get(cfg.hidden_size, "bias")?
.dequantize(vb.device())?;

let patch_embedding = Conv2d::new(pe_weight, Some(pe_bias), conv_cfg);
let num_patches1 = cfg.image_size / cfg.patch_size;
let num_patches = num_patches1 * num_patches1;
let num_positions = num_patches + 1;
let position_embedding = vb
.get((1, num_positions, cfg.hidden_size), "position_embedding")?
.dequantize(vb.device())?;
Ok(Self {
class_embedding,
patch_embedding,
position_embedding,
})
}
}

impl Module for VisionEmbeddings {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let target_dtype = xs.dtype();
let b_size = xs.dim(0)?;
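// Shapes: (b, 3, h, w) -conv2d-> (b, hidden, h/p, w/p) -flatten-> (b, hidden, n_patches) -t-> (b, n_patches, hidden).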
let patch_embeds = xs.apply(&self.patch_embedding)?.flatten_from(2)?.t()?;
let d = self.class_embedding.dim(D::Minus1)?;
let class_embeds = self
.class_embedding
.broadcast_as((b_size, 1, d))?
.to_dtype(target_dtype)?;
let embeddings = Tensor::cat(&[&class_embeds, &patch_embeds], 1)?;
let position_embedding = self.position_embedding.narrow(1, 0, embeddings.dim(1)?)?;
embeddings.broadcast_add(&position_embedding)
}
}

#[derive(Debug, Clone)]
struct Attention {
qkv: Linear,
projection: Linear,
scale: f64,
num_heads: usize,
}

impl Attention {
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
let embed_dim = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let head_dim = embed_dim / num_heads;
let scale = 1f64 / (head_dim as f64).sqrt();
let qkv = linear(embed_dim, 3 * embed_dim, vb.pp("qkv"))?;
let projection = linear(embed_dim, embed_dim, vb.pp("projection"))?;
Ok(Self {
qkv,
projection,
scale,
num_heads,
})
}

fn forward(&self, xs: &Tensor, attn_mask: Option<&Tensor>) -> Result<Tensor> {
let (b_sz, tgt_len, embed_dim) = xs.dims3()?;
let mixed_qkv = xs
.apply(&self.qkv)?
.reshape((b_sz, tgt_len, 3, self.num_heads, embed_dim / self.num_heads))?
.permute((2, 0, 3, 1, 4))?;
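// mixed_qkv has shape (3, b_sz, num_heads, tgt_len, head_dim); indexing the leading dim splits q/k/v.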
let query = mixed_qkv.get(0)?;
let key = mixed_qkv.get(1)?;
let value = mixed_qkv.get(2)?;
let attention_scores = query.matmul(&key.t()?)?;
let attention_scores = (attention_scores * self.scale)?;
let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
let attention_probs = match attn_mask {
None => attention_probs,
Some(attn_mask) => (attention_probs * attn_mask)?,
};
attention_probs
.matmul(&value)?
.permute((0, 2, 1, 3))?
.flatten_from(D::Minus2)?
.apply(&self.projection)
}
}

#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
activation_fn: candle_nn::Activation,
fc1: Linear,
fc2: Linear,
}

impl MLP {
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
let fc1 = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("fc1"))?;
let fc2 = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
Ok(Self {
activation_fn: cfg.hidden_act,
fc1,
fc2,
})
}
}

impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.fc1)?
.apply(&self.activation_fn)?
.apply(&self.fc2)
}
}

#[derive(Debug, Clone)]
struct EncoderLayer {
self_attn: Attention,
layer_norm1: LayerNorm,
mlp: MLP,
layer_norm2: LayerNorm,
}

impl EncoderLayer {
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
let embed_dim = cfg.hidden_size;
let self_attn = Attention::new(cfg, vb.pp("self_attn"))?;
let layer_norm1 = layer_norm(embed_dim, cfg.layer_norm_eps, vb.pp("layer_norm1"))?;
let layer_norm2 = layer_norm(embed_dim, cfg.layer_norm_eps, vb.pp("layer_norm2"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
Ok(Self {
self_attn,
layer_norm1,
mlp,
layer_norm2,
})
}

fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
let residual = xs;
let xs = xs.apply(&self.layer_norm1)?;
let xs = self.self_attn.forward(&xs, attention_mask)?;
let xs = (xs + residual)?;

let residual = &xs;
let xs = xs.apply(&self.layer_norm2)?.apply(&self.mlp)?;
xs + residual
}
}

#[derive(Debug, Clone)]
struct Encoder {
layers: Vec<EncoderLayer>,
}

impl Encoder {
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb = vb.pp("layers");
for i in 0..cfg.num_hidden_layers {
let layer = EncoderLayer::new(cfg, vb.pp(i))?;
layers.push(layer)
}
Ok(Self { layers })
}

fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward(&xs, attention_mask)?
}
Ok(xs)
}
}

#[derive(Debug, Clone)]
pub struct VisionModel {
embeddings: VisionEmbeddings,
encoder: Encoder,
post_layernorm: LayerNorm,
}

impl VisionModel {
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
let embeddings = VisionEmbeddings::new(cfg, vb.pp("embeddings"))?;
let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
let post_layernorm =
layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("post_layernorm"))?;
Ok(Self {
embeddings,
encoder,
post_layernorm,
})
}
}

impl Module for VisionModel {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = xs.apply(&self.embeddings)?;
let encoder_outputs = self.encoder.forward(&xs, None)?;
// Return the last hidden state rather than pooled outputs.
encoder_outputs.apply(&self.post_layernorm)
}
}

#[derive(Debug, Clone)]
pub struct BlipForConditionalGeneration {
vision_model: VisionModel,
text_decoder: blip_text::TextLMHeadModel,
}

impl BlipForConditionalGeneration {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vision_model = VisionModel::new(&cfg.vision_config, vb.pp("vision_model"))?;
let text_decoder =
blip_text::TextLMHeadModel::new(&cfg.text_config, vb.pp("text_decoder"))?;
Ok(Self {
vision_model,
text_decoder,
})
}

pub fn vision_model(&self) -> &VisionModel {
&self.vision_model
}

pub fn text_decoder(&mut self) -> &mut blip_text::TextLMHeadModel {
&mut self.text_decoder
}
}
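
For reference, a minimal sketch of driving this module directly, mirroring the example above; the function name and error handling are illustrative rather than part of the commit:

use candle::Tensor;
use candle_transformers::models::quantized_blip;

fn first_step_logits(image: &Tensor, model_file: &std::path::Path) -> anyhow::Result<Tensor> {
    // Load the q4k gguf weights and build the model.
    let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
    let config = quantized_blip::Config::image_captioning_large();
    let mut model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
    // Encode the image once; the embeddings are reused at every decoding step.
    let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
    // Start decoding from the BOS token (30522).
    let token_ids = Tensor::new(&[30522u32][..], image.device())?.unsqueeze(0)?;
    Ok(model.text_decoder().forward(&token_ids, &image_embeds)?)
}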
