diff --git a/candle-transformers/src/models/clip/text_model.rs b/candle-transformers/src/models/clip/text_model.rs
index 51db14ee0c..3d0530a985 100644
--- a/candle-transformers/src/models/clip/text_model.rs
+++ b/candle-transformers/src/models/clip/text_model.rs
@@ -77,7 +77,7 @@ impl ClipTextEmbeddings {
         )?;
         let position_ids =
             Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
-        Ok(ClipTextEmbeddings {
+        Ok(Self {
             token_embedding,
             position_embedding,
             position_ids,
diff --git a/candle-transformers/src/models/siglip.rs b/candle-transformers/src/models/siglip.rs
index 315f331a30..61028d1a68 100644
--- a/candle-transformers/src/models/siglip.rs
+++ b/candle-transformers/src/models/siglip.rs
@@ -1,5 +1,5 @@
 #![allow(unused)]
-use candle::{Result, Tensor};
+use candle::{Result, Tensor, D};
 use candle_nn::{layer_norm, linear, LayerNorm, Linear, VarBuilder};
 
 // https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L27
@@ -151,6 +151,7 @@ impl Attention {
     }
 }
 
+// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/modeling_siglip.py#L599
 #[derive(Debug, Clone)]
 struct Mlp {
     fc1: Linear,
@@ -181,6 +182,7 @@ impl candle::Module for Mlp {
     }
 }
 
+// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/modeling_siglip.py#L614
 #[derive(Debug, Clone)]
 struct EncoderLayer {
     self_attn: Attention,
@@ -273,6 +275,18 @@ impl VisionEmbeddings {
     }
 }
 
+impl candle::Module for VisionEmbeddings {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let (_, _, height, width) = xs.dims4()?;
+        let embeddings = xs
+            .apply(&self.patch_embedding)?
+            .flatten_from(2)?
+            .transpose(1, 2)?;
+        let position_embedding = self.position_embedding.forward(&self.position_ids)?;
+        embeddings.broadcast_add(&position_embedding)
+    }
+}
+
 #[derive(Debug, Clone)]
 struct VisionTransformer {
     embeddings: VisionEmbeddings,
@@ -295,6 +309,14 @@ impl VisionTransformer {
     }
 }
 
+impl candle::Module for VisionTransformer {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = xs.apply(&self.embeddings)?;
+        let xs = self.encoder.forward(&xs, None)?;
+        xs.apply(&self.post_layernorm)
+    }
+}
+
 #[derive(Debug, Clone)]
 pub struct VisionModel {
     vision_model: VisionTransformer,
@@ -306,3 +328,84 @@ impl VisionModel {
         Ok(Self { vision_model })
     }
 }
+
+impl candle::Module for VisionModel {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.vision_model)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextEmbeddings {
+    token_embedding: candle_nn::Embedding,
+    position_embedding: candle_nn::Embedding,
+    position_ids: Tensor,
+}
+
+impl TextEmbeddings {
+    fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
+        let token_embedding =
+            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("token_embedding"))?;
+        let position_embedding = candle_nn::embedding(
+            cfg.max_position_embeddings,
+            cfg.hidden_size,
+            vb.pp("position_embedding"),
+        )?;
+        let position_ids =
+            Tensor::arange(0u32, cfg.max_position_embeddings as u32, vb.device())?.unsqueeze(0)?;
+        Ok(Self {
+            token_embedding,
+            position_embedding,
+            position_ids,
+        })
+    }
+}
+
+impl candle::Module for TextEmbeddings {
+    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+        let seq_length = input_ids.dim(D::Minus1)?;
+        let inputs_embeds = self.token_embedding.forward(input_ids)?;
+        let position_ids = self.position_ids.narrow(1, 0, seq_length)?;
+        let position_embedding = self.position_embedding.forward(&position_ids)?;
+        inputs_embeds.broadcast_add(&position_embedding)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextTransformer {
+    embeddings: TextEmbeddings,
+    encoder: Encoder,
+    final_layer_norm: LayerNorm,
+    head: Linear,
+}
+
+impl TextTransformer {
+    fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
+        let embeddings = TextEmbeddings::new(cfg, vb.pp("embeddings"))?;
+        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
+        let final_layer_norm = layer_norm(
+            cfg.hidden_size,
+            cfg.layer_norm_eps,
+            vb.pp("final_layer_norm"),
+        )?;
+        let head = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("head"))?;
+        Ok(Self {
+            embeddings,
+            encoder,
+            final_layer_norm,
+            head,
+        })
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct TextModel {
+    text_model: TextTransformer,
+}
+
+impl TextModel {
+    fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
+        let text_model = TextTransformer::new(cfg, vb.pp("text_model"))?;
+        Ok(Self { text_model })
+    }
+}
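
A note on the `VisionEmbeddings::forward` added above: the `flatten_from(2)?.transpose(1, 2)?` pair turns the conv patch-embedding output of shape `(batch, hidden_size, grid_h, grid_w)` into the `(batch, num_patches, hidden_size)` layout the encoder expects. A minimal standalone sketch of just that reshape, with illustrative shapes not taken from any real checkpoint:

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Stand-in for the conv2d patch-embedding output:
    // (batch, hidden_size, grid_h, grid_w) = (1, 8, 3, 3).
    let conv_out = Tensor::randn(0f32, 1f32, (1, 8, 3, 3), &dev)?;
    // flatten_from(2) merges the two grid dims into one: (1, 8, 9) ...
    let flat = conv_out.flatten_from(2)?;
    // ... and transpose(1, 2) yields (batch, num_patches, hidden_size): (1, 9, 8).
    let patches = flat.transpose(1, 2)?;
    assert_eq!(patches.dims(), &[1, 9, 8]);
    Ok(())
}
```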
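Similarly, `TextEmbeddings::forward` narrows the cached `(1, max_position_embeddings)` `position_ids` buffer to the current sequence length and relies on `broadcast_add` to repeat the resulting `(1, seq_len, hidden)` position embeddings across the batch dimension. A self-contained sketch with toy sizes and randomly initialized tables (nothing here is loaded from a checkpoint):

```rust
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::Embedding;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (vocab_size, max_pos, hidden, batch, seq_len) = (32, 16, 4, 2, 5);

    // Toy token / position tables, randomly initialized for the demo.
    let token_embedding =
        Embedding::new(Tensor::randn(0f32, 1f32, (vocab_size, hidden), &dev)?, hidden);
    let position_embedding =
        Embedding::new(Tensor::randn(0f32, 1f32, (max_pos, hidden), &dev)?, hidden);
    // (1, max_pos) position ids, as cached in TextEmbeddings::new.
    let position_ids = Tensor::arange(0u32, max_pos as u32, &dev)?.unsqueeze(0)?;

    // A (batch, seq_len) batch of token ids.
    let input_ids = Tensor::zeros((batch, seq_len), DType::U32, &dev)?;
    let inputs_embeds = token_embedding.forward(&input_ids)?; // (batch, seq_len, hidden)

    // Narrow to the current sequence length, embed, and broadcast over the batch.
    let seq = input_ids.dim(D::Minus1)?;
    let pos = position_embedding.forward(&position_ids.narrow(1, 0, seq)?)?;
    let out = inputs_embeds.broadcast_add(&pos)?;
    assert_eq!(out.dims(), &[batch, seq_len, hidden]);
    Ok(())
}
```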