Add the SigLIP model.
LaurentMazare committed Sep 28, 2024
1 parent 62525e8 commit 657bf64
Showing 3 changed files with 315 additions and 6 deletions.
12 changes: 6 additions & 6 deletions candle-transformers/src/models/fastvit.rs
@@ -11,13 +11,13 @@ use candle_nn::{
     BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
 };

-#[derive(Clone, Debug)]
+#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
 pub struct Config {
-    exp_ratio: usize,
-    in_channels: usize,
-    blocks: [usize; 4],
-    attn: bool,
-    lkc_use_act: bool,
+    pub exp_ratio: usize,
+    pub in_channels: usize,
+    pub blocks: [usize; 4],
+    pub attn: bool,
+    pub lkc_use_act: bool,
 }

 impl Config {
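The fastvit change above only derives serde Serialize/Deserialize on Config and makes its fields public. A minimal round-trip sketch of what that enables (an illustration, not code from the commit; it assumes serde_json and anyhow as extra dependencies, and the field values are made up rather than a real FastViT variant):

use candle_transformers::models::fastvit::Config;

fn main() -> anyhow::Result<()> {
    // The fields are public now, so a config can be built directly.
    let cfg = Config {
        exp_ratio: 4,
        in_channels: 64,
        blocks: [2, 2, 6, 2],
        attn: false,
        lkc_use_act: true,
    };
    // Round-trip through JSON, exercising the new serde derives.
    let json = serde_json::to_string_pretty(&cfg)?;
    let parsed: Config = serde_json::from_str(&json)?;
    println!("{parsed:?}");
    Ok(())
}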
1 change: 1 addition & 0 deletions candle-transformers/src/models/mod.rs
@@ -76,6 +76,7 @@ pub mod rwkv_v5;
 pub mod rwkv_v6;
 pub mod segformer;
 pub mod segment_anything;
+pub mod siglip;
 pub mod stable_diffusion;
 pub mod stable_lm;
 pub mod starcoder2;
308 changes: 308 additions & 0 deletions candle-transformers/src/models/siglip.rs
@@ -0,0 +1,308 @@
#![allow(unused)]
use candle::{Result, Tensor};
use candle_nn::{layer_norm, linear, LayerNorm, Linear, VarBuilder};

// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L27
#[derive(serde::Deserialize, Clone, Debug)]
pub struct TextConfig {
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub max_position_embeddings: usize,
pub hidden_act: candle_nn::Activation,
pub layer_norm_eps: f64,
pub pad_token_id: u32,
pub bos_token_id: u32,
pub eos_token_id: u32,
}

// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L132
#[derive(serde::Deserialize, Clone, Debug)]
pub struct VisionConfig {
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_channels: usize,
pub image_size: usize,
pub patch_size: usize,
pub hidden_act: candle_nn::Activation,
pub layer_norm_eps: f64,
}

trait TransformerConfig {
fn hidden_size(&self) -> usize;
fn intermediate_size(&self) -> usize;
fn num_attention_heads(&self) -> usize;
fn num_hidden_layers(&self) -> usize;
fn layer_norm_eps(&self) -> f64;
fn hidden_act(&self) -> candle_nn::Activation;
}

impl TransformerConfig for TextConfig {
fn hidden_size(&self) -> usize {
self.hidden_size
}
fn intermediate_size(&self) -> usize {
self.intermediate_size
}
fn num_attention_heads(&self) -> usize {
self.num_attention_heads
}
fn num_hidden_layers(&self) -> usize {
self.num_hidden_layers
}
fn layer_norm_eps(&self) -> f64 {
self.layer_norm_eps
}
fn hidden_act(&self) -> candle_nn::Activation {
self.hidden_act
}
}

impl TransformerConfig for VisionConfig {
fn hidden_size(&self) -> usize {
self.hidden_size
}
fn intermediate_size(&self) -> usize {
self.intermediate_size
}
fn num_attention_heads(&self) -> usize {
self.num_attention_heads
}
fn num_hidden_layers(&self) -> usize {
self.num_hidden_layers
}
fn layer_norm_eps(&self) -> f64 {
self.layer_norm_eps
}
fn hidden_act(&self) -> candle_nn::Activation {
self.hidden_act
}
}

// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L228
#[derive(serde::Deserialize, Clone, Debug)]
pub struct Config {
pub text_config: TextConfig,
pub vision_config: VisionConfig,
}

#[derive(Debug, Clone)]
struct MultiheadAttentionPoolingHead {}

#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
out_proj: Linear,
num_heads: usize,
head_dim: usize,
scale: f64,
}

impl Attention {
fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
let embed_dim = cfg.hidden_size();
let q_proj = linear(embed_dim, embed_dim, vb.pp("q_proj"))?;
let k_proj = linear(embed_dim, embed_dim, vb.pp("k_proj"))?;
let v_proj = linear(embed_dim, embed_dim, vb.pp("v_proj"))?;
let out_proj = linear(embed_dim, embed_dim, vb.pp("out_proj"))?;
let num_heads = cfg.num_attention_heads();
let head_dim = embed_dim / num_heads;
Ok(Self {
q_proj,
k_proj,
v_proj,
out_proj,
num_heads,
head_dim,
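// 1 / sqrt(head_dim), the usual scaled dot-product attention factor.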
scale: (head_dim as f64).powf(-0.5),
})
}

fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
let (batch_size, q_len, _) = xs.dims3()?;
let query_states = xs.apply(&self.q_proj)?;
let key_states = xs.apply(&self.k_proj)?;
let value_states = xs.apply(&self.v_proj)?;

let shape = (batch_size, q_len, self.num_heads, self.head_dim);
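// Split into heads: reshape to (b, q_len, n_heads, head_dim), then transpose to (b, n_heads, q_len, head_dim).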
let query_states = query_states.reshape(shape)?.transpose(1, 2)?;
let key_states = key_states.reshape(shape)?.transpose(1, 2)?;
let value_states = value_states.reshape(shape)?.transpose(1, 2)?;

let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?;
let attn_weights = match attention_mask {
None => attn_weights,
Some(mask) => attn_weights.broadcast_add(mask)?,
};
// The original implementation upcasts to f32 but candle_nn::ops::softmax should handle this properly.
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
let attn_outputs = attn_weights
.matmul(&value_states)?
.transpose(1, 2)?
.reshape((batch_size, q_len, ()))?
.apply(&self.out_proj)?;
Ok(attn_outputs)
}
}

#[derive(Debug, Clone)]
struct Mlp {
fc1: Linear,
fc2: Linear,
activation_fn: candle_nn::Activation,
}

impl Mlp {
fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
let hidden_size = cfg.hidden_size();
let intermediate_size = cfg.intermediate_size();
let hidden_act = cfg.hidden_act();
let fc1 = candle_nn::linear(hidden_size, intermediate_size, vb.pp("fc1"))?;
let fc2 = candle_nn::linear(intermediate_size, hidden_size, vb.pp("fc2"))?;
Ok(Self {
fc1,
fc2,
activation_fn: hidden_act,
})
}
}

impl candle::Module for Mlp {
fn forward(&self, xs: &candle::Tensor) -> Result<candle::Tensor> {
xs.apply(&self.fc1)?
.apply(&self.activation_fn)?
.apply(&self.fc2)
}
}

#[derive(Debug, Clone)]
struct EncoderLayer {
self_attn: Attention,
layer_norm1: LayerNorm,
mlp: Mlp,
layer_norm2: LayerNorm,
}

impl EncoderLayer {
fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
let hidden_size = cfg.hidden_size();
let layer_norm_eps = cfg.layer_norm_eps();
let self_attn = Attention::new(cfg, vb.pp("self_attn"))?;
let layer_norm1 = layer_norm(hidden_size, layer_norm_eps, vb.pp("layer_norm1"))?;
let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
let layer_norm2 = layer_norm(hidden_size, layer_norm_eps, vb.pp("layer_norm2"))?;
Ok(Self {
self_attn,
layer_norm1,
mlp,
layer_norm2,
})
}

fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
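// Pre-norm transformer block: layer norm -> self-attention -> residual add, then layer norm -> MLP -> residual add.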
let residual = xs;
let xs = xs.apply(&self.layer_norm1)?;
let xs = self.self_attn.forward(&xs, attention_mask)?;
let xs = (residual + xs)?;
let residual = &xs;
let xs = xs.apply(&self.layer_norm2)?.apply(&self.mlp)?;
xs + residual
}
}

#[derive(Debug, Clone)]
struct Encoder {
layers: Vec<EncoderLayer>,
}

impl Encoder {
fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
let mut layers = vec![];
let vb = vb.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers() {
let layer = EncoderLayer::new(cfg, vb.pp(layer_idx))?;
layers.push(layer)
}
Ok(Self { layers })
}

fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward(&xs, attention_mask)?
}
Ok(xs)
}
}

#[derive(Debug, Clone)]
struct VisionEmbeddings {
patch_embedding: candle_nn::Conv2d,
position_embedding: candle_nn::Embedding,
position_ids: Tensor,
}

impl VisionEmbeddings {
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
let conv2d_cfg = candle_nn::Conv2dConfig {
stride: cfg.patch_size,
..Default::default()
};
let patch_embedding = candle_nn::conv2d(
cfg.num_channels,
cfg.hidden_size,
cfg.patch_size,
conv2d_cfg,
vb.pp("patch_embedding"),
)?;
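// One learned position embedding per non-overlapping patch: (image_size / patch_size)^2 entries.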
let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
let position_ids = Tensor::arange(0, num_patches as i64, vb.device())?;
let position_embedding =
candle_nn::embedding(num_patches, cfg.hidden_size(), vb.pp("position_embedding"))?;
Ok(Self {
patch_embedding,
position_embedding,
position_ids,
})
}
}

#[derive(Debug, Clone)]
struct VisionTransformer {
embeddings: VisionEmbeddings,
encoder: Encoder,
post_layernorm: LayerNorm,
// head: Option<MultiheadAttentionPoolingHead>,
}

impl VisionTransformer {
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
let embeddings = VisionEmbeddings::new(cfg, vb.pp("embeddings"))?;
let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
let post_layernorm =
layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("post_layernorm"))?;
Ok(Self {
embeddings,
encoder,
post_layernorm,
})
}
}

#[derive(Debug, Clone)]
pub struct VisionModel {
vision_model: VisionTransformer,
}

impl VisionModel {
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
let vision_model = VisionTransformer::new(cfg, vb.pp("vision_model"))?;
Ok(Self { vision_model })
}
}
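This initial siglip.rs covers config parsing and the vision-tower building blocks; there is no text tower or full forward pass yet, the attention-pooling head is a stub, and VisionModel::new is private to the crate. The public Config types can already be deserialized from a SigLIP-style config.json, though. A minimal sketch (not part of the commit; it assumes serde_json and anyhow as dependencies, and the values below are illustrative rather than taken from a released checkpoint):

use candle_transformers::models::siglip::Config;

fn main() -> anyhow::Result<()> {
    // Shaped like a Hugging Face SigLIP config.json, with made-up values.
    let json = r#"{
        "text_config": {
            "vocab_size": 32000,
            "hidden_size": 768,
            "intermediate_size": 3072,
            "num_hidden_layers": 12,
            "num_attention_heads": 12,
            "max_position_embeddings": 64,
            "hidden_act": "gelu",
            "layer_norm_eps": 1e-6,
            "pad_token_id": 1,
            "bos_token_id": 49406,
            "eos_token_id": 49407
        },
        "vision_config": {
            "hidden_size": 768,
            "intermediate_size": 3072,
            "num_hidden_layers": 12,
            "num_attention_heads": 12,
            "num_channels": 3,
            "image_size": 224,
            "patch_size": 16,
            "hidden_act": "gelu",
            "layer_norm_eps": 1e-6
        }
    }"#;
    let cfg: Config = serde_json::from_str(json)?;
    println!("{} vision layers", cfg.vision_config.num_hidden_layers);
    Ok(())
}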
