diff --git a/candle-examples/examples/clip/main.rs b/candle-examples/examples/clip/main.rs
index d057663d5f..273edb6a0a 100644
--- a/candle-examples/examples/clip/main.rs
+++ b/candle-examples/examples/clip/main.rs
@@ -12,7 +12,6 @@ use candle_nn::{ops::softmax, VarBuilder};
 use candle_transformers::models::clip;
 
 use tokenizers::Tokenizer;
-use tracing::info;
 
 #[derive(Parser)]
 struct Args {
@@ -40,15 +39,12 @@ fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
         height as u32,
         image::imageops::FilterType::Triangle,
     );
-
     let img = img.to_rgb8();
-
     let img = img.into_raw();
     let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
         .permute((2, 0, 1))?
         .to_dtype(DType::F32)?
         .affine(2. / 255., -1.)?;
-    // .unsqueeze(0)?;
     Ok(img)
 }
 
@@ -57,24 +53,16 @@ fn load_images<T: AsRef<std::path::Path>>(
     image_size: usize,
 ) -> anyhow::Result<Tensor> {
     let mut images = vec![];
-
     for path in paths {
         let tensor = load_image(path, image_size)?;
         images.push(tensor);
     }
-
     let images = Tensor::stack(&images, 0)?;
-
     Ok(images)
 }
 
 pub fn main() -> anyhow::Result<()> {
-    // std::env::set_var("RUST_BACKTRACE", "full");
-
     let args = Args::parse();
-
-    tracing_subscriber::fmt::init();
-
     let model_file = match args.model {
         None => {
             let api = hf_hub::api::sync::Api::new()?;
@@ -89,13 +77,9 @@ pub fn main() -> anyhow::Result<()> {
         }
         Some(model) => model.into(),
     };
-
     let tokenizer = get_tokenizer(args.tokenizer)?;
-
     let config = clip::ClipConfig::vit_base_patch32();
-
     let device = candle_examples::device(args.cpu)?;
-
     let vec_imgs = match args.images {
         Some(imgs) => imgs,
         None => vec![
@@ -103,43 +87,29 @@ pub fn main() -> anyhow::Result<()> {
             "candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
         ],
     };
-
-    // let image = load_image(args.image, config.image_size)?.to_device(&device)?;
     let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
-
     let vb =
         unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
-
     let model = clip::ClipModel::new(vb, &config)?;
-
     let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
-
     let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
-
     let softmax_image = softmax(&logits_per_image, 1)?;
-
     let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
-
-    info!("softmax_image_vec: {:?}", softmax_image_vec);
-
+    println!("softmax_image_vec: {:?}", softmax_image_vec);
     let probability_vec = softmax_image_vec
         .iter()
         .map(|v| v * 100.0)
         .collect::<Vec<f32>>();
-
     let probability_per_image = probability_vec.len() / vec_imgs.len();
-
     for (i, img) in vec_imgs.iter().enumerate() {
         let start = i * probability_per_image;
         let end = start + probability_per_image;
         let prob = &probability_vec[start..end];
-        info!("\n\nResults for image: {}\n", img);
-
+        println!("\n\nResults for image: {}\n", img);
         for (i, p) in prob.iter().enumerate() {
-            info!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
+            println!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
         }
     }
-
     Ok(())
 }
 
@@ -156,7 +126,6 @@ pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
         }
         Some(file) => file.into(),
     };
-
     Tokenizer::from_file(tokenizer).map_err(E::msg)
 }
 
@@ -169,7 +138,6 @@ pub fn tokenize_sequences(
         .get_vocab(true)
         .get("<|endoftext|>")
         .ok_or(E::msg("No pad token"))?;
-
     let vec_seq = match sequences {
         Some(seq) => seq,
         None => vec![
@@ -178,16 +146,12 @@ pub fn tokenize_sequences(
             "a robot holding a candle".to_string(),
         ],
     };
-
     let mut tokens = vec![];
-
     for seq in vec_seq.clone() {
         let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
         tokens.push(encoding.get_ids().to_vec());
     }
-
     let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
-
     // Pad the sequences to have the same length
     for token_vec in tokens.iter_mut() {
         let len_diff = max_len - token_vec.len();
@@ -195,8 +159,6 @@ pub fn tokenize_sequences(
             token_vec.extend(vec![pad_id; len_diff]);
         }
     }
-
     let input_ids = Tensor::new(tokens, device)?;
-
     Ok((input_ids, vec_seq))
 }
diff --git a/candle-examples/examples/siglip/README.md b/candle-examples/examples/siglip/README.md
new file mode 100644
index 0000000000..d79ae33062
--- /dev/null
+++ b/candle-examples/examples/siglip/README.md
@@ -0,0 +1,24 @@
+## SigLIP
+
+SigLIP is a multi-modal text-vision model that improves over CLIP by using a sigmoid-based loss,
+[HuggingFace](https://huggingface.co/google/siglip-base-patch16-224).
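+
+Because the training objective is a sigmoid over individual image-text pairs, each logit can be
+read as an independent pairwise score rather than as part of a distribution over candidates. As a
+rough editorial sketch (not part of the example code below, which reports a softmax over the
+candidate texts), the logits returned by `Model::forward` can instead be passed through a sigmoid:
+
+```rust
+// Hypothetical post-processing of the logits produced by the example's forward pass.
+let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
+// Each entry lands in [0, 1] on its own; rows do not need to sum to 1.
+let probs = candle_nn::ops::sigmoid(&logits_per_image)?;
+```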
+
+### Running an example
+```
+$ cargo run --features cuda -r --example siglip -
+softmax_image_vec: [2.1912122e-14, 2.3624872e-14, 1.0, 1.0, 2.4787932e-8, 3.2784535e-12]
+
+
+Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
+
+Probability: 0.0000% Text: a cycling race
+Probability: 0.0000% Text: a photo of two cats
+Probability: 100.0000% Text: a robot holding a candle
+
+
+Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
+
+Probability: 100.0000% Text: a cycling race
+Probability: 0.0000% Text: a photo of two cats
+Probability: 0.0000% Text: a robot holding a candle
+```
diff --git a/candle-examples/examples/siglip/main.rs b/candle-examples/examples/siglip/main.rs
new file mode 100644
index 0000000000..be953c8764
--- /dev/null
+++ b/candle-examples/examples/siglip/main.rs
@@ -0,0 +1,153 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::Error as E;
+use clap::Parser;
+
+use candle::{DType, Device, Tensor};
+use candle_nn::{ops::softmax, VarBuilder};
+use candle_transformers::models::siglip;
+
+use tokenizers::Tokenizer;
+
+#[derive(Parser)]
+struct Args {
+    #[arg(long)]
+    model: Option<String>,
+
+    #[arg(long)]
+    tokenizer: Option<String>,
+
+    #[arg(long, use_value_delimiter = true)]
+    images: Option<Vec<String>>,
+
+    #[arg(long)]
+    cpu: bool,
+
+    #[arg(long, use_value_delimiter = true)]
+    sequences: Option<Vec<String>>,
+}
+
+fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
+    let img = image::ImageReader::open(path)?.decode()?;
+    let (height, width) = (image_size, image_size);
+    let img = img.resize_to_fill(
+        width as u32,
+        height as u32,
+        image::imageops::FilterType::Triangle,
+    );
+    let img = img.to_rgb8();
+    let img = img.into_raw();
+    let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
+        .permute((2, 0, 1))?
+        .to_dtype(DType::F32)?
+        .affine(2. / 255., -1.)?;
+    Ok(img)
+}
+
+fn load_images<T: AsRef<std::path::Path>>(
+    paths: &Vec<T>,
+    image_size: usize,
+) -> anyhow::Result<Tensor> {
+    let mut images = vec![];
+    for path in paths {
+        let tensor = load_image(path, image_size)?;
+        images.push(tensor);
+    }
+    let images = Tensor::stack(&images, 0)?;
+    Ok(images)
+}
+
+pub fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+    let model_file = match args.model {
+        None => {
+            let api = hf_hub::api::sync::Api::new()?;
+            let api = api.model("google/siglip-base-patch16-224".to_string());
+            api.get("model.safetensors")?
+        }
+        Some(model) => model.into(),
+    };
+    let tokenizer = get_tokenizer(args.tokenizer)?;
+    let config = siglip::Config::base_patch16_224();
+    let device = candle_examples::device(args.cpu)?;
+    let vec_imgs = match args.images {
+        Some(imgs) => imgs,
+        None => vec![
+            "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
+            "candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
+        ],
+    };
+    let images = load_images(&vec_imgs, config.vision_config.image_size)?.to_device(&device)?;
+    let vb =
+        unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
+    let model = siglip::Model::new(&config, vb)?;
+    let (input_ids, vec_seq) = tokenize_sequences(&config, args.sequences, &tokenizer, &device)?;
+    let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
+    let softmax_image = softmax(&logits_per_image, 1)?;
+    let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
+    println!("softmax_image_vec: {:?}", softmax_image_vec);
+    let probability_vec = softmax_image_vec
+        .iter()
+        .map(|v| v * 100.0)
+        .collect::<Vec<f32>>();
+    let probability_per_image = probability_vec.len() / vec_imgs.len();
+    for (i, img) in vec_imgs.iter().enumerate() {
+        let start = i * probability_per_image;
+        let end = start + probability_per_image;
+        let prob = &probability_vec[start..end];
+        println!("\n\nResults for image: {}\n", img);
+        for (i, p) in prob.iter().enumerate() {
+            println!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
+        }
+    }
+    Ok(())
+}
+
+pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
+    let tokenizer = match tokenizer {
+        None => {
+            let api = hf_hub::api::sync::Api::new()?;
+            let api = api.model("google/siglip-base-patch16-224".to_string());
+            api.get("tokenizer.json")?
+        }
+        Some(file) => file.into(),
+    };
+
+    Tokenizer::from_file(tokenizer).map_err(E::msg)
+}
+
+pub fn tokenize_sequences(
+    config: &siglip::Config,
+    sequences: Option<Vec<String>>,
+    tokenizer: &Tokenizer,
+    device: &Device,
+) -> anyhow::Result<(Tensor, Vec<String>)> {
+    let pad_id = config.text_config.pad_token_id;
+    let vec_seq = match sequences {
+        Some(seq) => seq,
+        None => vec![
+            "a cycling race".to_string(),
+            "a photo of two cats".to_string(),
+            "a robot holding a candle".to_string(),
+        ],
+    };
+    let mut tokens = vec![];
+    for seq in vec_seq.clone() {
+        let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
+        tokens.push(encoding.get_ids().to_vec());
+    }
+    let max_len = config.text_config.max_position_embeddings;
+    // Pad the sequences to have the same length
+    for token_vec in tokens.iter_mut() {
+        let len_diff = max_len - token_vec.len();
+        if len_diff > 0 {
+            token_vec.extend(vec![pad_id; len_diff]);
+        }
+    }
+    let input_ids = Tensor::new(tokens, device)?;
+    Ok((input_ids, vec_seq))
+}
diff --git a/candle-transformers/src/models/clip/mod.rs b/candle-transformers/src/models/clip/mod.rs
index 9613fdab0e..3dd5fb485b 100644
--- a/candle-transformers/src/models/clip/mod.rs
+++ b/candle-transformers/src/models/clip/mod.rs
@@ -92,28 +92,23 @@ impl ClipConfig {
 impl ClipModel {
     pub fn new(vs: candle_nn::VarBuilder, c: &ClipConfig) -> Result<Self> {
         let text_model = ClipTextTransformer::new(vs.pp("text_model"), &c.text_config)?;
-
         let vision_model = ClipVisionTransformer::new(vs.pp("vision_model"), &c.vision_config)?;
-
         let visual_projection = candle_nn::linear_no_bias(
             c.vision_config.embed_dim,
             c.vision_config.projection_dim,
             vs.pp("visual_projection"),
         )?;
-
         let text_projection = candle_nn::linear_no_bias(
             c.text_config.embed_dim,
             c.text_config.projection_dim,
             vs.pp("text_projection"),
         )?;
-
         // originally nn.Parameter
         let logit_scale = if vs.contains_tensor("logit_scale") {
             vs.get(&[], "logit_scale")?
         } else {
             Tensor::new(&[c.logit_scale_init_value], vs.device())?
         };
-
         Ok(Self {
             text_model,
             vision_model,
diff --git a/candle-transformers/src/models/clip/text_model.rs b/candle-transformers/src/models/clip/text_model.rs
index 51db14ee0c..4662f65fda 100644
--- a/candle-transformers/src/models/clip/text_model.rs
+++ b/candle-transformers/src/models/clip/text_model.rs
@@ -77,7 +77,7 @@ impl ClipTextEmbeddings {
         )?;
         let position_ids =
             Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
-        Ok(ClipTextEmbeddings {
+        Ok(Self {
             token_embedding,
             position_embedding,
             position_ids,
@@ -298,7 +298,7 @@ impl ClipTextTransformer {
         })
     }
 
-    // TODO: rewrrite to newer version
+    // TODO: rewrite to newer version
     fn build_causal_attention_mask(
         bsz: usize,
         seq_len: usize,
diff --git a/candle-transformers/src/models/fastvit.rs b/candle-transformers/src/models/fastvit.rs
index b7bdaf888a..8eae8bb200 100644
--- a/candle-transformers/src/models/fastvit.rs
+++ b/candle-transformers/src/models/fastvit.rs
@@ -11,13 +11,13 @@ use candle_nn::{
     BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
 };
 
-#[derive(Clone, Debug)]
+#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
 pub struct Config {
-    exp_ratio: usize,
-    in_channels: usize,
-    blocks: [usize; 4],
-    attn: bool,
-    lkc_use_act: bool,
+    pub exp_ratio: usize,
+    pub in_channels: usize,
+    pub blocks: [usize; 4],
+    pub attn: bool,
+    pub lkc_use_act: bool,
 }
 
 impl Config {
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index dae41a807c..a0e7a9225b 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -76,6 +76,7 @@ pub mod rwkv_v5;
 pub mod rwkv_v6;
 pub mod segformer;
 pub mod segment_anything;
+pub mod siglip;
 pub mod stable_diffusion;
 pub mod stable_lm;
 pub mod starcoder2;
diff --git a/candle-transformers/src/models/siglip.rs b/candle-transformers/src/models/siglip.rs
new file mode 100644
index 0000000000..a3280a86fc
--- /dev/null
+++ b/candle-transformers/src/models/siglip.rs
@@ -0,0 +1,608 @@
+use crate::models::clip::div_l2_norm;
+use candle::{IndexOp, Module, Result, Tensor, D};
+use candle_nn::{layer_norm, linear, LayerNorm, Linear, VarBuilder};
+
+// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L27
+#[derive(serde::Deserialize, Clone, Debug)]
+pub struct TextConfig {
+    pub vocab_size: usize,
+    pub hidden_size: usize,
+    pub intermediate_size: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub max_position_embeddings: usize,
+    pub hidden_act: candle_nn::Activation,
+    pub layer_norm_eps: f64,
+    pub pad_token_id: u32,
+    pub bos_token_id: u32,
+    pub eos_token_id: u32,
+}
+
+// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L132
+#[derive(serde::Deserialize, Clone, Debug)]
+pub struct VisionConfig {
+    pub hidden_size: usize,
+    pub intermediate_size: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub num_channels: usize,
+    pub image_size: usize,
+    pub patch_size: usize,
+    pub hidden_act: candle_nn::Activation,
+    pub layer_norm_eps: f64,
+}
+
+trait TransformerConfig {
+    fn hidden_size(&self) -> usize;
+    fn intermediate_size(&self) -> usize;
+    fn num_attention_heads(&self) -> usize;
+    fn num_hidden_layers(&self) -> usize;
+    fn layer_norm_eps(&self) -> f64;
+    fn hidden_act(&self) -> candle_nn::Activation;
+}
+
+impl TransformerConfig for TextConfig {
+    fn hidden_size(&self) -> usize {
+        self.hidden_size
+    }
+    fn intermediate_size(&self) -> usize {
+        self.intermediate_size
+    }
+    fn num_attention_heads(&self) -> usize {
+        self.num_attention_heads
+    }
+    fn num_hidden_layers(&self) -> usize {
+        self.num_hidden_layers
+    }
+    fn layer_norm_eps(&self) -> f64 {
+        self.layer_norm_eps
+    }
+    fn hidden_act(&self) -> candle_nn::Activation {
+        self.hidden_act
+    }
+}
+
+impl TransformerConfig for VisionConfig {
+    fn hidden_size(&self) -> usize {
+        self.hidden_size
+    }
+    fn intermediate_size(&self) -> usize {
+        self.intermediate_size
+    }
+    fn num_attention_heads(&self) -> usize {
+        self.num_attention_heads
+    }
+    fn num_hidden_layers(&self) -> usize {
+        self.num_hidden_layers
+    }
+    fn layer_norm_eps(&self) -> f64 {
+        self.layer_norm_eps
+    }
+    fn hidden_act(&self) -> candle_nn::Activation {
+        self.hidden_act
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L228
+#[derive(serde::Deserialize, Clone, Debug)]
+pub struct Config {
+    pub text_config: TextConfig,
+    pub vision_config: VisionConfig,
+}
+
+impl Config {
+    pub fn base_patch16_224() -> Self {
+        let text_config = TextConfig {
+            // https://huggingface.co/google/siglip-base-patch16-224/blob/main/config.json
+            hidden_size: 768,
+            intermediate_size: 3072,
+            num_attention_heads: 12,
+            vocab_size: 32000,
+            // Default values.
+            pad_token_id: 1,
+            bos_token_id: 49406,
+            eos_token_id: 49407,
+            layer_norm_eps: 1e-6,
+            hidden_act: candle_nn::Activation::GeluPytorchTanh,
+            max_position_embeddings: 64,
+            num_hidden_layers: 12,
+        };
+        let vision_config = VisionConfig {
+            patch_size: 16,
+            // Default values.
+            hidden_size: 768,
+            intermediate_size: 3072,
+            num_hidden_layers: 12,
+            num_attention_heads: 12,
+            num_channels: 3,
+            image_size: 224,
+            hidden_act: candle_nn::Activation::GeluPytorchTanh,
+            layer_norm_eps: 1e-6,
+        };
+        Self {
+            text_config,
+            vision_config,
+        }
+    }
+}
+
+#[derive(Clone, Debug)]
+struct MultiheadAttention {
+    q_proj: Linear,
+    k_proj: Linear,
+    v_proj: Linear,
+    out_proj: Linear,
+    num_heads: usize,
+}
+
+impl MultiheadAttention {
+    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let h = cfg.hidden_size;
+        let num_heads = cfg.num_attention_heads;
+        let w_in_proj = vb.get((3 * h, h), "in_proj_weight")?.chunk(3, 0)?;
+        let b_in_proj = vb.get(3 * h, "in_proj_bias")?.chunk(3, 0)?;
+        let q_proj = Linear::new(w_in_proj[0].clone(), Some(b_in_proj[0].clone()));
+        let k_proj = Linear::new(w_in_proj[1].clone(), Some(b_in_proj[1].clone()));
+        let v_proj = Linear::new(w_in_proj[2].clone(), Some(b_in_proj[2].clone()));
+        let out_proj = linear(h, h, vb.pp("out_proj"))?;
+        Ok(Self {
+            q_proj,
+            k_proj,
+            v_proj,
+            out_proj,
+            num_heads,
+        })
+    }
+
+    fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
+        let (b, n, c) = x.dims3()?;
+        x.reshape((b, n, self.num_heads, c / self.num_heads))?
+            .transpose(1, 2)?
+            .contiguous()
+    }
+
+    fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
+        let (b, n_heads, n_tokens, c_per_head) = x.dims4()?;
+        x.transpose(1, 2)?
+            .reshape((b, n_tokens, n_heads * c_per_head))
+    }
+
+    fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
+        let q = self.q_proj.forward(&q.contiguous()?)?;
+        let k = self.k_proj.forward(&k.contiguous()?)?;
+        let v = self.v_proj.forward(&v.contiguous()?)?;
+
+        let q = self.separate_heads(&q)?;
+        let k = self.separate_heads(&k)?;
+        let v = self.separate_heads(&v)?;
+
+        let (_, _, _, c_per_head) = q.dims4()?;
+        let attn = (q.matmul(&k.t()?)? / (c_per_head as f64).sqrt())?;
+        let attn = candle_nn::ops::softmax_last_dim(&attn)?;
+
+        let out = attn.matmul(&v)?;
+        self.recombine_heads(&out)?.apply(&self.out_proj)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct MultiheadAttentionPoolingHead {
+    probe: Tensor,
+    attention: MultiheadAttention,
+    layernorm: LayerNorm,
+    mlp: Mlp,
+}
+
+impl MultiheadAttentionPoolingHead {
+    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
+        let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("layernorm"))?;
+        let probe = vb.get((1, 1, cfg.hidden_size), "probe")?;
+        let attention = MultiheadAttention::new(cfg, vb.pp("attention"))?;
+        Ok(Self {
+            probe,
+            attention,
+            layernorm,
+            mlp,
+        })
+    }
+}
+
+impl Module for MultiheadAttentionPoolingHead {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let batch_size = xs.dim(0)?;
+        let probe = self.probe.repeat((batch_size, 1, 1))?;
+        let xs = self.attention.forward(&probe, xs, xs)?;
+        let residual = &xs;
+        let xs = xs.apply(&self.layernorm)?.apply(&self.mlp)?;
+        (xs + residual)?.i((.., 0))
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Attention {
+    q_proj: Linear,
+    k_proj: Linear,
+    v_proj: Linear,
+    out_proj: Linear,
+    num_heads: usize,
+    head_dim: usize,
+    scale: f64,
+}
+
+impl Attention {
+    fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
+        let embed_dim = cfg.hidden_size();
+        let q_proj = linear(embed_dim, embed_dim, vb.pp("q_proj"))?;
+        let k_proj = linear(embed_dim, embed_dim, vb.pp("k_proj"))?;
+        let v_proj = linear(embed_dim, embed_dim, vb.pp("v_proj"))?;
+        let out_proj = linear(embed_dim, embed_dim, vb.pp("out_proj"))?;
+        let num_heads = cfg.num_attention_heads();
+        let head_dim = embed_dim / num_heads;
+        Ok(Self {
+            q_proj,
+            k_proj,
+            v_proj,
+            out_proj,
+            num_heads,
+            head_dim,
+            scale: (head_dim as f64).powf(-0.5),
+        })
+    }
+
+    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
+        let (batch_size, q_len, _) = xs.dims3()?;
+        let query_states = xs.apply(&self.q_proj)?;
+        let key_states = xs.apply(&self.k_proj)?;
+        let value_states = xs.apply(&self.v_proj)?;
+
+        let shape = (batch_size, q_len, self.num_heads, self.head_dim);
+        let query_states = query_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
+        let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
+        let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
+
+        let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?;
+        let attn_weights = match attention_mask {
+            None => attn_weights,
+            Some(mask) => attn_weights.broadcast_add(mask)?,
+        };
+        // The original implementation upcasts to f32 but candle_nn::ops::softmax should handle this properly.
+        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+        let attn_outputs = attn_weights
+            .matmul(&value_states)?
+            .transpose(1, 2)?
+            .reshape((batch_size, q_len, ()))?
+            .apply(&self.out_proj)?;
+        Ok(attn_outputs)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/modeling_siglip.py#L599
+#[derive(Debug, Clone)]
+struct Mlp {
+    fc1: Linear,
+    fc2: Linear,
+    activation_fn: candle_nn::Activation,
+}
+
+impl Mlp {
+    fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
+        let hidden_size = cfg.hidden_size();
+        let intermediate_size = cfg.intermediate_size();
+        let fc1 = candle_nn::linear(hidden_size, intermediate_size, vb.pp("fc1"))?;
+        let fc2 = candle_nn::linear(intermediate_size, hidden_size, vb.pp("fc2"))?;
+        Ok(Self {
+            fc1,
+            fc2,
+            activation_fn: cfg.hidden_act(),
+        })
+    }
+}
+
+impl Module for Mlp {
+    fn forward(&self, xs: &candle::Tensor) -> Result<Tensor> {
+        xs.apply(&self.fc1)?
+            .apply(&self.activation_fn)?
+            .apply(&self.fc2)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/modeling_siglip.py#L614
+#[derive(Debug, Clone)]
+struct EncoderLayer {
+    self_attn: Attention,
+    layer_norm1: LayerNorm,
+    mlp: Mlp,
+    layer_norm2: LayerNorm,
+}
+
+impl EncoderLayer {
+    fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
+        let hidden_size = cfg.hidden_size();
+        let layer_norm_eps = cfg.layer_norm_eps();
+        let self_attn = Attention::new(cfg, vb.pp("self_attn"))?;
+        let layer_norm1 = layer_norm(hidden_size, layer_norm_eps, vb.pp("layer_norm1"))?;
+        let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
+        let layer_norm2 = layer_norm(hidden_size, layer_norm_eps, vb.pp("layer_norm2"))?;
+        Ok(Self {
+            self_attn,
+            layer_norm1,
+            mlp,
+            layer_norm2,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
+        let residual = xs;
+        let xs = xs.apply(&self.layer_norm1)?;
+        let xs = self.self_attn.forward(&xs, attention_mask)?;
+        let xs = (residual + xs)?;
+        let residual = &xs;
+        let xs = xs.apply(&self.layer_norm2)?.apply(&self.mlp)?;
+        let xs = (xs + residual)?;
+        Ok(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Encoder {
+    layers: Vec<EncoderLayer>,
+}
+
+impl Encoder {
+    fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
+        let mut layers = vec![];
+        let vb = vb.pp("layers");
+        for layer_idx in 0..cfg.num_hidden_layers() {
+            let layer = EncoderLayer::new(cfg, vb.pp(layer_idx))?;
+            layers.push(layer)
+        }
+        Ok(Self { layers })
+    }
+
+    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
+        let mut xs = xs.clone();
+        for layer in self.layers.iter() {
+            xs = layer.forward(&xs, attention_mask)?
+        }
+        Ok(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct VisionEmbeddings {
+    patch_embedding: candle_nn::Conv2d,
+    position_embedding: candle_nn::Embedding,
+    position_ids: Tensor,
+}
+
+impl VisionEmbeddings {
+    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let conv2d_cfg = candle_nn::Conv2dConfig {
+            stride: cfg.patch_size,
+            ..Default::default()
+        };
+        let patch_embedding = candle_nn::conv2d(
+            cfg.num_channels,
+            cfg.hidden_size,
+            cfg.patch_size,
+            conv2d_cfg,
+            vb.pp("patch_embedding"),
+        )?;
+        let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
+        let position_ids = Tensor::arange(0, num_patches as i64, vb.device())?;
+        let position_embedding =
+            candle_nn::embedding(num_patches, cfg.hidden_size(), vb.pp("position_embedding"))?;
+        Ok(Self {
+            patch_embedding,
+            position_embedding,
+            position_ids,
+        })
+    }
+}
+
+impl Module for VisionEmbeddings {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let (_batch, _channels, _height, _width) = xs.dims4()?;
+        let embeddings = xs.apply(&self.patch_embedding)?;
+        let embeddings = embeddings.flatten_from(2)?.transpose(1, 2)?;
+        let position_embedding = self.position_embedding.forward(&self.position_ids)?;
+        embeddings.broadcast_add(&position_embedding)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct VisionTransformer {
+    embeddings: VisionEmbeddings,
+    encoder: Encoder,
+    post_layernorm: LayerNorm,
+    head: Option<MultiheadAttentionPoolingHead>,
+}
+
+impl VisionTransformer {
+    fn new(cfg: &VisionConfig, use_head: bool, vb: VarBuilder) -> Result<Self> {
+        let embeddings = VisionEmbeddings::new(cfg, vb.pp("embeddings"))?;
+        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
+        let post_layernorm =
+            layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("post_layernorm"))?;
+        let head = if use_head {
+            Some(MultiheadAttentionPoolingHead::new(cfg, vb.pp("head"))?)
+        } else {
+            None
+        };
+        Ok(Self {
+            embeddings,
+            encoder,
+            post_layernorm,
+            head,
+        })
+    }
+}
+
+impl Module for VisionTransformer {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = xs.apply(&self.embeddings)?;
+        let xs = self.encoder.forward(&xs, None)?;
+        let xs = xs.apply(&self.post_layernorm)?;
+        match self.head.as_ref() {
+            None => Ok(xs),
+            Some(h) => xs.apply(h),
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct VisionModel {
+    vision_model: VisionTransformer,
+}
+
+impl VisionModel {
+    pub fn new(cfg: &VisionConfig, use_head: bool, vb: VarBuilder) -> Result<Self> {
+        let vision_model = VisionTransformer::new(cfg, use_head, vb)?;
+        Ok(Self { vision_model })
+    }
+}
+
+impl Module for VisionModel {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.vision_model)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextEmbeddings {
+    token_embedding: candle_nn::Embedding,
+    position_embedding: candle_nn::Embedding,
+    position_ids: Tensor,
+}
+
+impl TextEmbeddings {
+    fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
+        let token_embedding =
+            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("token_embedding"))?;
+        let position_embedding = candle_nn::embedding(
+            cfg.max_position_embeddings,
+            cfg.hidden_size,
+            vb.pp("position_embedding"),
+        )?;
+        let position_ids =
+            Tensor::arange(0u32, cfg.max_position_embeddings as u32, vb.device())?.unsqueeze(0)?;
+        Ok(Self {
+            token_embedding,
+            position_embedding,
+            position_ids,
+        })
+    }
+}
+
+impl Module for TextEmbeddings {
+    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+        let seq_length = input_ids.dim(D::Minus1)?;
+        let inputs_embeds = self.token_embedding.forward(input_ids)?;
+        let position_ids = self.position_ids.narrow(1, 0, seq_length)?;
+        let position_embedding = self.position_embedding.forward(&position_ids)?;
+        inputs_embeds.broadcast_add(&position_embedding)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct TextTransformer {
+    embeddings: TextEmbeddings,
+    encoder: Encoder,
+    final_layer_norm: LayerNorm,
+    pub head: Linear,
+}
+
+impl TextTransformer {
+    fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
+        let embeddings = TextEmbeddings::new(cfg, vb.pp("embeddings"))?;
+        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
+        let final_layer_norm = layer_norm(
+            cfg.hidden_size,
+            cfg.layer_norm_eps,
+            vb.pp("final_layer_norm"),
+        )?;
+        let head = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("head"))?;
+        Ok(Self {
+            embeddings,
+            encoder,
+            final_layer_norm,
+            head,
+        })
+    }
+}
+impl Module for TextTransformer {
+    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+        let (_bsz, seq_len) = input_ids.dims2()?;
+        let input_ids = self.embeddings.forward(input_ids)?;
+        let input_ids = self.encoder.forward(&input_ids, None)?;
+        let last_hidden_state = self.final_layer_norm.forward(&input_ids)?;
+        last_hidden_state
+            .i((.., seq_len - 1, ..))?
+            .contiguous()?
+            .apply(&self.head)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct TextModel {
+    pub text_model: TextTransformer,
+}
+
+impl TextModel {
+    pub fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
+        let text_model = TextTransformer::new(cfg, vb)?;
+        Ok(Self { text_model })
+    }
+}
+
+impl Module for TextModel {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.text_model)
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct Model {
+    text_model: TextModel,
+    vision_model: VisionModel,
+    logit_bias: Tensor,
+    logit_scale: Tensor,
+}
+
+impl Model {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let text_model = TextModel::new(&cfg.text_config, vb.pp("text_model"))?;
+        let vision_model = VisionModel::new(&cfg.vision_config, true, vb.pp("vision_model"))?;
+        let logit_scale = vb.get(&[1], "logit_scale")?;
+        let logit_bias = vb.get(&[1], "logit_bias")?;
+        Ok(Self {
+            text_model,
+            vision_model,
+            logit_bias,
+            logit_scale,
+        })
+    }
+
+    pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
+        input_ids.apply(&self.text_model)
+    }
+
+    pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
+        pixel_values.apply(&self.vision_model)
+    }
+
+    pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
+        let image_features = self.get_image_features(pixel_values)?;
+        let text_features = self.get_text_features(input_ids)?;
+        let image_features_normalized = div_l2_norm(&image_features)?;
+        let text_features_normalized = div_l2_norm(&text_features)?;
+        let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
+        let logit_scale = self.logit_scale.exp()?;
+        let logits_per_text = logits_per_text
+            .broadcast_mul(&logit_scale)?
+            .broadcast_add(&self.logit_bias)?;
+        let logits_per_image = logits_per_text.t()?;
+        Ok((logits_per_text, logits_per_image))
+    }
+}
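
As a closing editorial sketch (not part of the patch itself): besides `forward`, the new `siglip::Model` exposes `get_image_features` and `get_text_features`, which can be normalized into unit-length embeddings for retrieval-style comparisons. The snippet assumes `model`, `images`, and `input_ids` are built exactly as in `candle-examples/examples/siglip/main.rs` above.

```rust
// Sketch: L2-normalized SigLIP embeddings and a cosine-similarity matrix.
// Assumes `model: siglip::Model`, `images: Tensor`, `input_ids: Tensor` from the example's main().
use candle::D;

let image_features = model.get_image_features(&images)?; // (n_images, hidden_size)
let text_features = model.get_text_features(&input_ids)?; // (n_texts, hidden_size)
// Dividing each row by its L2 norm makes a plain matmul equal to cosine similarity.
let image_norm = image_features.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
let text_norm = text_features.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
let image_features = image_features.broadcast_div(&image_norm)?;
let text_features = text_features.broadcast_div(&text_norm)?;
let similarity = image_features.matmul(&text_features.t()?)?; // (n_images, n_texts)
println!("{similarity}");
```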