Add more to the forward pass of the vision model.
LaurentMazare committed Sep 28, 2024
1 parent 657bf64 commit 1897c32
Showing 2 changed files with 105 additions and 2 deletions.
candle-transformers/src/models/clip/text_model.rs (2 changes: 1 addition & 1 deletion)
@@ -77,7 +77,7 @@ impl ClipTextEmbeddings {
        )?;
        let position_ids =
            Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
-        Ok(ClipTextEmbeddings {
+        Ok(Self {
            token_embedding,
            position_embedding,
            position_ids,
candle-transformers/src/models/siglip.rs (105 changes: 104 additions & 1 deletion)
@@ -1,5 +1,5 @@
#![allow(unused)]
-use candle::{Result, Tensor};
+use candle::{Result, Tensor, D};
use candle_nn::{layer_norm, linear, LayerNorm, Linear, VarBuilder};

// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L27
@@ -151,6 +151,7 @@ impl Attention {
    }
}

// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/modeling_siglip.py#L599
#[derive(Debug, Clone)]
struct Mlp {
    fc1: Linear,
@@ -181,6 +182,7 @@ impl candle::Module for Mlp {
    }
}

// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/modeling_siglip.py#L614
#[derive(Debug, Clone)]
struct EncoderLayer {
    self_attn: Attention,
@@ -273,6 +275,18 @@ impl VisionEmbeddings {
    }
}

impl candle::Module for VisionEmbeddings {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (_, _, height, width) = xs.dims4()?;
        let embeddings = xs
            .apply(&self.patch_embedding)?
            .flatten_from(2)?
            .transpose(1, 2)?;
        let position_embedding = self.position_embedding.forward(&self.position_ids)?;
        embeddings.broadcast_add(&position_embedding)
    }
}
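
// A minimal standalone sketch (not part of this diff): the forward above relies on
// flatten_from(2) + transpose(1, 2) to turn the conv output (batch, hidden, grid, grid)
// into the (batch, num_patches, hidden) sequence the encoder expects. The sizes below
// (hidden 8, a 14x14 grid, i.e. 224/16 patches per side) are illustrative only.
fn patch_reshape_demo() -> candle::Result<()> {
    use candle::{DType, Device, Tensor};
    let dev = Device::Cpu;
    // Stand-in for the patch_embedding conv output.
    let xs = Tensor::zeros((1, 8, 14, 14), DType::F32, &dev)?;
    let seq = xs.flatten_from(2)?.transpose(1, 2)?;
    assert_eq!(seq.dims(), &[1, 14 * 14, 8]); // (batch, num_patches, hidden)
    Ok(())
}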

#[derive(Debug, Clone)]
struct VisionTransformer {
    embeddings: VisionEmbeddings,
@@ -295,6 +309,14 @@ impl VisionTransformer {
    }
}

impl candle::Module for VisionTransformer {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = xs.apply(&self.embeddings)?;
        let xs = self.encoder.forward(&xs, None)?;
        xs.apply(&self.post_layernorm)
    }
}

#[derive(Debug, Clone)]
pub struct VisionModel {
    vision_model: VisionTransformer,
@@ -306,3 +328,84 @@ impl VisionModel {
        Ok(Self { vision_model })
    }
}

impl candle::Module for VisionModel {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.vision_model)
    }
}

#[derive(Debug, Clone)]
struct TextEmbeddings {
    token_embedding: candle_nn::Embedding,
    position_embedding: candle_nn::Embedding,
    position_ids: Tensor,
}

impl TextEmbeddings {
    fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
        let token_embedding =
            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("token_embedding"))?;
        let position_embedding = candle_nn::embedding(
            cfg.max_position_embeddings,
            cfg.hidden_size,
            vb.pp("position_embedding"),
        )?;
        let position_ids =
            Tensor::arange(0u32, cfg.max_position_embeddings as u32, vb.device())?.unsqueeze(0)?;
        Ok(Self {
            token_embedding,
            position_embedding,
            position_ids,
        })
    }
}

impl candle::Module for TextEmbeddings {
    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let seq_length = input_ids.dim(D::Minus1)?;
        let inputs_embeds = self.token_embedding.forward(input_ids)?;
        let position_ids = self.position_ids.narrow(1, 0, seq_length)?;
        let position_embedding = self.position_embedding.forward(&position_ids)?;
        inputs_embeds.broadcast_add(&position_embedding)
    }
}
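
// A minimal standalone sketch (not part of this diff): the narrow(1, 0, seq_length)
// above is what lets one (1, max_position_embeddings) buffer serve any prompt length.
// max_position_embeddings = 64 and seq_length = 10 are made up for the demo.
fn position_ids_demo() -> candle::Result<()> {
    use candle::{Device, Tensor};
    let dev = Device::Cpu;
    let position_ids = Tensor::arange(0u32, 64u32, &dev)?.unsqueeze(0)?; // (1, 64)
    let ids = position_ids.narrow(1, 0, 10)?; // keep only the first 10 positions
    assert_eq!(ids.dims(), &[1, 10]);
    Ok(())
}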

#[derive(Debug, Clone)]
struct TextTransformer {
    embeddings: TextEmbeddings,
    encoder: Encoder,
    final_layer_norm: LayerNorm,
    head: Linear,
}

impl TextTransformer {
    fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
        let embeddings = TextEmbeddings::new(cfg, vb.pp("embeddings"))?;
        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
        let final_layer_norm = layer_norm(
            cfg.hidden_size,
            cfg.layer_norm_eps,
            vb.pp("final_layer_norm"),
        )?;
        let head = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("head"))?;
        Ok(Self {
            embeddings,
            encoder,
            final_layer_norm,
            head,
        })
    }
}

#[derive(Debug, Clone)]
pub struct TextModel {
    text_model: TextTransformer,
}

impl TextModel {
    fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
        let text_model = TextTransformer::new(cfg, vb.pp("text_model"))?;
        Ok(Self { text_model })
    }
}
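
For completeness, a minimal usage sketch of the new vision forward pass (not part of the commit). It assumes VisionModel::new takes (&VisionConfig, VarBuilder) like the other constructors here, that the config exposes an image_size field, and that the types are reachable from the call site; none of that is shown in this excerpt.

use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;

fn vision_forward_demo(cfg: &VisionConfig) -> Result<Tensor> {
    let dev = Device::Cpu;
    // Zero-initialized weights are enough to exercise shapes; real use would load safetensors.
    let vb = VarBuilder::zeros(DType::F32, &dev);
    let model = VisionModel::new(cfg, vb)?;
    // One RGB image at the configured square resolution.
    let pixels = Tensor::zeros((1, 3, cfg.image_size, cfg.image_size), DType::F32, &dev)?;
    // Result: (1, num_patches, hidden_size), with the post layernorm applied.
    pixels.apply(&model)
}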
