Add ability to use any CLIPModel with model-id and revision #2527

Open · wants to merge 8 commits into `main`
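
In short, the clip example now resolves a Hub repository from `--model-id` and `--revision`, fetches `config.json`, `tokenizer.json`, and the weights, and deserializes the config directly into `ClipConfig` rather than relying on the hard-coded `vit_base_patch32` settings. A minimal sketch of that flow, reusing the same `hf_hub`, `tokenizers`, and `serde_json` calls that appear in the diff below (error handling trimmed; this is not the exact example code):

```rust
use candle_transformers::models::clip::ClipConfig;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

/// Resolve a CLIP checkpoint by model id + revision and load its config,
/// tokenizer, and weight file path from the Hugging Face Hub.
fn fetch_clip(
    model_id: &str,
    revision: &str,
    use_pth: bool,
) -> anyhow::Result<(ClipConfig, Tokenizer, std::path::PathBuf)> {
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        model_id.to_string(),
        RepoType::Model,
        revision.to_string(),
    ));
    // ClipConfig now derives Deserialize, so any CLIP checkpoint's config.json parses directly.
    let config: ClipConfig = serde_json::from_str(&std::fs::read_to_string(repo.get("config.json")?)?)?;
    let tokenizer = Tokenizer::from_file(repo.get("tokenizer.json")?).map_err(anyhow::Error::msg)?;
    let weights = if use_pth {
        repo.get("pytorch_model.bin")?
    } else {
        repo.get("model.safetensors")?
    };
    Ok((config, tokenizer, weights))
}
```

The CLI wiring in `main.rs` below does the same thing, with local `--model`/`--tokenizer` paths taking precedence over the Hub downloads.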
2 changes: 1 addition & 1 deletion candle-examples/examples/bert/main.rs
@@ -76,7 +76,7 @@ impl Args {
};
(config, tokenizer, weights)
};
let config = std::fs::read_to_string(config_filename)?;
let config: String = std::fs::read_to_string(config_filename)?;
let mut config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

31 changes: 23 additions & 8 deletions candle-examples/examples/clip/README.md
@@ -13,19 +13,34 @@ https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585
$ cargo run --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"


Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
> Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg

INFO clip: Probability: 0.0000% Text: a cycling race
INFO clip: Probability: 0.0000% Text: a photo of two cats
INFO clip: Probability: 100.0000% Text: a robot holding a candle
> INFO clip: Probability: 0.0000% Text: a cycling race
> INFO clip: Probability: 0.0000% Text: a photo of two cats
> INFO clip: Probability: 100.0000% Text: a robot holding a candle

Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
> Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg

INFO clip: Probability: 99.9999% Text: a cycling race
INFO clip: Probability: 0.0001% Text: a photo of two cats
INFO clip: Probability: 0.0000% Text: a robot holding a candle
> INFO clip: Probability: 99.9999% Text: a cycling race
> INFO clip: Probability: 0.0001% Text: a photo of two cats
> INFO clip: Probability: 0.0000% Text: a robot holding a candle
```

### Arguments
- `--model`: local path to the model weights. If not provided, they are downloaded from huggingface.
- `--tokenizer`: local path to the tokenizer.json file. If not provided, it is downloaded from huggingface.
- `--images`: list of images to use.

Example: `--images candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg,candle-examples/examples/yolo-v8/assets/bike.jpg`
- `--sequences`: list of text sequences to use.

Example: `--sequences "a cycling race","bike"`

- `--model-id`: model id to use from huggingface. Example: `--model-id openai/clip-vit-large-patch14`
- `--revision`: revision to use from huggingface. Example: `--revision refs/pr/4`
- `--use-pth`: use the pytorch weights rather than the safetensors ones. Default: `false`
- `--cpu`: run on the CPU. Omit this flag to run on the GPU (requires building with `--features cuda`).
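
As an illustration, pointing the example at a different checkpoint could look like the following (the model id is the one from the example above; add `--revision` if the checkpoint you pick requires a specific one):

```
$ cargo run --example clip --release -- --model-id openai/clip-vit-large-patch14 --images "candle-examples/examples/yolo-v8/assets/bike.jpg" --sequences "a cycling race","a photo of two cats"
```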

## Running on an example with metal feature (mac)

```
95 changes: 61 additions & 34 deletions candle-examples/examples/clip/main.rs
@@ -9,26 +9,42 @@ use clap::Parser;

use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::clip;
use candle_transformers::models::clip::{self, ClipConfig};

use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

#[derive(Parser)]
struct Args {
/// model to use from a local file. If not provided, it will be downloaded from huggingface
#[arg(long)]
model: Option<String>,

#[arg(long)]
tokenizer: Option<String>,

/// The images to use
#[arg(long, use_value_delimiter = true)]
images: Option<Vec<String>>,

#[arg(long)]
cpu: bool,

/// The sequences to use
#[arg(long, use_value_delimiter = true)]
sequences: Option<Vec<String>>,

/// The model_id to use from huggingface; check out available models: https://huggingface.co/openai
#[arg(long)]
model_id: Option<String>,

/// The revision to use from huggingface
#[arg(long)]
revision: Option<String>,

/// Use the pytorch weights rather than the safetensors ones
#[arg(long, default_value = "false")]
use_pth: bool,
}

fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
@@ -63,33 +79,60 @@ fn load_images<T: AsRef<std::path::Path>>(

pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;

let api = api.repo(hf_hub::Repo::with_revision(
"openai/clip-vit-base-patch32".to_string(),
hf_hub::RepoType::Model,
"refs/pr/15".to_string(),
));
tracing_subscriber::fmt::init();

api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let tokenizer = get_tokenizer(args.tokenizer)?;
let config = clip::ClipConfig::vit_base_patch32();
let device = candle_examples::device(args.cpu)?;
let default_model = "openai/clip-vit-base-patch16".to_string();
let default_revision = "refs/pr/4".to_string();
let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
(None, Some(revision)) => (default_model, revision),
(None, None) => (default_model, default_revision),
};

let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let (config_filename, tokenizer_filename, weights_filename) = {
let api = Api::new()?;
let api = api.repo(repo);
let config = api.get("config.json")?;
let tokenizer = match args.tokenizer {
Some(tokenizer) => tokenizer.into(),
None => api.get("tokenizer.json")?,
};
let weights = match args.model {
Some(model) => model.into(),
None => {
if args.use_pth {
api.get("pytorch_model.bin")?
} else {
api.get("model.safetensors")?
}
}
};

(config, tokenizer, weights)
};

let config: String = std::fs::read_to_string(config_filename)?;
let config: ClipConfig = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

let vec_imgs = match args.images {
Some(imgs) => imgs,
None => vec![
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };

let images = load_images(&vec_imgs, config.vision_config.image_size)?.to_device(&device)?;

let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[weights_filename.clone()], DType::F32, &device)?
};

let model = clip::ClipModel::new(vb, &config)?;
let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
@@ -113,22 +156,6 @@ pub fn main() -> anyhow::Result<()> {
Ok(())
}

pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
let tokenizer = match tokenizer {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.repo(hf_hub::Repo::with_revision(
"openai/clip-vit-base-patch32".to_string(),
hf_hub::RepoType::Model,
"refs/pr/15".to_string(),
));
api.get("tokenizer.json")?
}
Some(file) => file.into(),
};
Tokenizer::from_file(tokenizer).map_err(E::msg)
}

pub fn tokenize_sequences(
sequences: Option<Vec<String>>,
tokenizer: &Tokenizer,
4 changes: 2 additions & 2 deletions candle-examples/examples/flux/main.rs
@@ -133,9 +133,9 @@ fn run(args: Args) -> Result<()> {
let config = clip::text_model::ClipTextConfig {
vocab_size: 49408,
projection_dim: 768,
activation: clip::text_model::Activation::QuickGelu,
hidden_act: clip::text_model::Activation::QuickGelu,
intermediate_size: 3072,
embed_dim: 768,
hidden_size: 768,
max_position_embeddings: 77,
pad_with: None,
num_hidden_layers: 12,
15 changes: 7 additions & 8 deletions candle-transformers/src/models/clip/mod.rs
@@ -10,6 +10,7 @@ use self::{
vision_model::ClipVisionTransformer,
};
use candle::{Result, Tensor, D};
use serde::Deserialize;

pub mod text_model;
pub mod vision_model;
@@ -32,8 +33,8 @@ pub enum EncoderConfig {
impl EncoderConfig {
pub fn embed_dim(&self) -> usize {
match self {
Self::Text(c) => c.embed_dim,
Self::Vision(c) => c.embed_dim,
Self::Text(c) => c.hidden_size,
Self::Vision(c) => c.hidden_size,
}
}

@@ -61,17 +62,16 @@ impl EncoderConfig {
pub fn activation(&self) -> Activation {
match self {
Self::Text(_c) => Activation::QuickGelu,
Self::Vision(c) => c.activation,
Self::Vision(c) => c.hidden_act,
}
}
}

#[derive(Clone, Debug)]
#[derive(Clone, Debug, Deserialize)]
pub struct ClipConfig {
pub text_config: text_model::ClipTextConfig,
pub vision_config: vision_model::ClipVisionConfig,
pub logit_scale_init_value: f32,
pub image_size: usize,
}

impl ClipConfig {
@@ -84,7 +84,6 @@ impl ClipConfig {
text_config,
vision_config,
logit_scale_init_value: 2.6592,
image_size: 224,
}
}
}
@@ -94,12 +93,12 @@ impl ClipModel {
let text_model = ClipTextTransformer::new(vs.pp("text_model"), &c.text_config)?;
let vision_model = ClipVisionTransformer::new(vs.pp("vision_model"), &c.vision_config)?;
let visual_projection = candle_nn::linear_no_bias(
c.vision_config.embed_dim,
c.vision_config.hidden_size,
c.vision_config.projection_dim,
vs.pp("visual_projection"),
)?;
let text_projection = candle_nn::linear_no_bias(
c.text_config.embed_dim,
c.text_config.hidden_size,
c.text_config.projection_dim,
vs.pp("text_projection"),
)?;
37 changes: 20 additions & 17 deletions candle-transformers/src/models/clip/text_model.rs
@@ -9,10 +9,12 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;
use serde::Deserialize;

use super::EncoderConfig;

#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Activation {
QuickGelu,
}
@@ -25,11 +27,11 @@ impl Module for Activation {
}
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Deserialize)]
pub struct ClipTextConfig {
pub vocab_size: usize,
pub embed_dim: usize,
pub activation: Activation,
pub hidden_size: usize,
pub hidden_act: Activation,
pub intermediate_size: usize,
pub max_position_embeddings: usize,
pub pad_with: Option<String>,
@@ -45,14 +47,14 @@ impl ClipTextConfig {
pub fn vit_base_patch32() -> Self {
Self {
vocab_size: 49408,
embed_dim: 512,
hidden_size: 512,
intermediate_size: 2048,
max_position_embeddings: 77,
pad_with: None,
num_hidden_layers: 12,
num_attention_heads: 8,
projection_dim: 512,
activation: Activation::QuickGelu,
hidden_act: Activation::QuickGelu,
}
}
}
@@ -69,10 +71,10 @@ struct ClipTextEmbeddings {
impl ClipTextEmbeddings {
fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result<Self> {
let token_embedding =
candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
candle_nn::embedding(c.vocab_size, c.hidden_size, vs.pp("token_embedding"))?;
let position_embedding: nn::Embedding = candle_nn::embedding(
c.max_position_embeddings,
c.embed_dim,
c.hidden_size,
vs.pp("position_embedding"),
)?;
let position_ids =
@@ -108,13 +110,13 @@ struct ClipAttention {

impl ClipAttention {
fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result<Self> {
let embed_dim = c.embed_dim();
let hidden_size = c.embed_dim();
let num_attention_heads = c.num_attention_heads();
let k_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?;
let v_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?;
let q_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?;
let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
let head_dim = embed_dim / num_attention_heads;
let k_proj = candle_nn::linear(hidden_size, hidden_size, vs.pp("k_proj"))?;
let v_proj = candle_nn::linear(hidden_size, hidden_size, vs.pp("v_proj"))?;
let q_proj = candle_nn::linear(hidden_size, hidden_size, vs.pp("q_proj"))?;
let out_proj = candle_nn::linear(hidden_size, hidden_size, vs.pp("out_proj"))?;
let head_dim = hidden_size / num_attention_heads;
let scale = (head_dim as f64).powf(-0.5);

Ok(ClipAttention {
@@ -136,7 +138,7 @@ impl ClipAttention {

fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
let in_dtype = xs.dtype();
let (bsz, seq_len, embed_dim) = xs.dims3()?;
let (bsz, seq_len, hidden_size) = xs.dims3()?;

let query_states = (self.q_proj.forward(xs)? * self.scale)?;
let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
@@ -171,7 +173,7 @@ impl ClipAttention {
let attn_output = attn_output
.reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
.transpose(1, 2)?
.reshape((bsz, seq_len, embed_dim))?;
.reshape((bsz, seq_len, hidden_size))?;
self.out_proj.forward(&attn_output)
}
}
@@ -290,7 +292,8 @@ impl ClipTextTransformer {
pub fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result<Self> {
let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Text(c.clone()))?;
let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
let final_layer_norm =
candle_nn::layer_norm(c.hidden_size, 1e-5, vs.pp("final_layer_norm"))?;
Ok(ClipTextTransformer {
embeddings,
encoder,