Blip fixes (#1145)
* Some fixes for the blip example.

* Stop generating on sep tokens.

* Clippy fixes.

* rustfmt.
LaurentMazare authored Oct 21, 2023
1 parent 0d9bb4e commit 2531b13
Showing 2 changed files with 68 additions and 31 deletions.
72 changes: 68 additions & 4 deletions candle-examples/examples/blip/main.rs
@@ -4,17 +4,24 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
 
+use anyhow::Error as E;
 use clap::Parser;
 
-use candle::DType;
+use candle::{DType, Device, Result, Tensor};
+use candle_examples::token_output_stream::TokenOutputStream;
 use candle_nn::VarBuilder;
 use candle_transformers::models::blip;
 
+use tokenizers::Tokenizer;
+
 #[derive(Parser)]
 struct Args {
     #[arg(long)]
     model: Option<String>,
 
+    #[arg(long)]
+    tokenizer: Option<String>,
+
     #[arg(long)]
     image: String,
 
@@ -23,12 +30,33 @@ struct Args {
     cpu: bool,
 }
 
+const SEP_TOKEN_ID: u32 = 102;
+
+/// Loads an image from disk using the image crate; this returns a tensor with shape
+/// (3, 384, 384). OpenAI's CLIP mean/std normalization is applied.
+pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
+    let img = image::io::Reader::open(p)?
+        .decode()
+        .map_err(candle::Error::wrap)?
+        .resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
+    let img = img.to_rgb8();
+    let data = img.into_raw();
+    let data = Tensor::from_vec(data, (384, 384, 3), &Device::Cpu)?.permute((2, 0, 1))?;
+    let mean =
+        Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], &Device::Cpu)?.reshape((3, 1, 1))?;
+    let std = Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], &Device::Cpu)?
+        .reshape((3, 1, 1))?;
+    (data.to_dtype(candle::DType::F32)? / 255.)?
+        .broadcast_sub(&mean)?
+        .broadcast_div(&std)
+}
+
 pub fn main() -> anyhow::Result<()> {
     let args = Args::parse();
 
     let device = candle_examples::device(args.cpu)?;
 
-    let image = candle_examples::imagenet::load_image224(args.image)?;
+    let image = load_image(args.image)?.to_device(&device)?;
     println!("loaded image {image:?}");
 
     let model_file = match args.model {
@@ -43,12 +71,48 @@ pub fn main() -> anyhow::Result<()> {
         }
         Some(model) => model.into(),
     };
+    let tokenizer = match args.tokenizer {
+        None => {
+            let api = hf_hub::api::sync::Api::new()?;
+            let api = api.model("Salesforce/blip-image-captioning-large".to_string());
+            api.get("tokenizer.json")?
+        }
+        Some(file) => file.into(),
+    };
+    let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
+    let mut tokenizer = TokenOutputStream::new(tokenizer);
+    let mut logits_processor =
+        candle_transformers::generation::LogitsProcessor::new(1337, None, None);
+
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
     let config = blip::Config::image_captioning_large();
     let model = blip::BlipForConditionalGeneration::new(&config, vb)?;
+    let vision_model = model.vision_model();
+    let text_decoder = model.text_decoder();
     println!("model built");
     // TODO: Maybe add support for the conditional prompt.
-    let out = model.generate(&image.unsqueeze(0)?, None, None)?;
-    println!(">>>\n{out}");
+    let image_embeds = image.unsqueeze(0)?.apply(vision_model)?;
+
+    let mut token_ids = vec![30522u32];
+    for _index in 0..1000 {
+        let input_ids = Tensor::new(token_ids.as_slice(), &device)?.broadcast_left(1)?;
+        let logits = text_decoder.forward(&input_ids, &image_embeds)?;
+        let logits = logits.squeeze(0)?;
+        let logits = logits.get(logits.dim(0)? - 1)?;
+        let token = logits_processor.sample(&logits)?;
+        if token == SEP_TOKEN_ID {
+            break;
+        }
+        token_ids.push(token);
+        if let Some(t) = tokenizer.next_token(token)? {
+            use std::io::Write;
+            print!("{t}");
+            std::io::stdout().flush()?;
+        }
+    }
+    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
+        print!("{rest}");
+    }
+
     Ok(())
 }
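
Two hard-coded token ids drive the new decoding loop: 30522 is the decoder start token that BLIP's tokenizer adds right after the 30,522-entry bert-base-uncased vocabulary, and 102 is BERT's [SEP] id, which is what the "stop generating on sep tokens" fix checks for. As a quick illustration of the new load_image helper's contract, a minimal caller might look like this (the image path is a placeholder, not part of the commit):

    // Sketch: exercising the load_image helper from the diff above.
    // "assets/bike.jpg" is a placeholder path, not a file in the repo.
    fn main() -> anyhow::Result<()> {
        let img = load_image("assets/bike.jpg")?;
        // Prints [3, 384, 384]: channels-first, 384x384, CLIP-normalized f32.
        println!("{:?}", img.dims());
        Ok(())
    }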
27 changes: 0 additions & 27 deletions candle-transformers/src/models/blip.rs
@@ -306,31 +306,4 @@ impl BlipForConditionalGeneration
     pub fn text_decoder(&self) -> &blip_text::TextLMHeadModel {
         &self.text_decoder
     }
-
-    pub fn generate(
-        &self,
-        pixel_values: &Tensor,
-        input_ids: Option<&Tensor>,
-        attention_mask: Option<&Tensor>,
-    ) -> Result<Tensor> {
-        let image_embeds = pixel_values.apply(&self.vision_model)?;
-        let b_size = image_embeds.dim(0)?;
-        if b_size > 1 {
-            candle::bail!("only a batch size of 1 is supported")
-        }
-        let mut logits_processor = crate::generation::LogitsProcessor::new(1337, None, None);
-        let mut token_ids = vec![30522u32];
-        for i in 0..1000 {
-            let input_ids =
-                Tensor::new(token_ids.as_slice(), pixel_values.device())?.broadcast_left(b_size)?;
-            let logits = self.text_decoder.forward(&input_ids, &image_embeds)?;
-            println!("{logits:?}");
-            let logits = logits.squeeze(0)?;
-            let logits = logits.get(logits.dim(0)? - 1)?;
-            let token = logits_processor.sample(&logits)?;
-            println!("{token}");
-            token_ids.push(token)
-        }
-        todo!()
-    }
 }
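
With generate() removed from BlipForConditionalGeneration, downstream code now owns the decoding loop, driving the model through the vision_model() and text_decoder() accessors just as the updated example above does. A condensed sketch of that migration path, assuming only the APIs visible in this diff (the caption_ids name is ours; streamed printing is elided):

    use candle::{Device, Result, Tensor};
    use candle_transformers::generation::LogitsProcessor;
    use candle_transformers::models::blip::BlipForConditionalGeneration;

    /// Caller-side replacement for the removed generate(): embed the image
    /// once, then repeatedly feed the growing token sequence to the decoder.
    fn caption_ids(
        model: &BlipForConditionalGeneration,
        image: &Tensor, // (3, 384, 384), normalized as in load_image
        device: &Device,
    ) -> Result<Vec<u32>> {
        let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
        let mut logits_processor = LogitsProcessor::new(1337, None, None);
        let mut token_ids = vec![30522u32]; // decoder start token, as in the example
        for _ in 0..1000 {
            let input_ids = Tensor::new(token_ids.as_slice(), device)?.broadcast_left(1)?;
            let logits = model.text_decoder().forward(&input_ids, &image_embeds)?;
            // Sample from the logits of the last position only.
            let logits = logits.squeeze(0)?;
            let logits = logits.get(logits.dim(0)? - 1)?;
            let token = logits_processor.sample(&logits)?;
            if token == 102 {
                break; // BERT [SEP]: stop generating, per this commit
            }
            token_ids.push(token);
        }
        Ok(token_ids)
    }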
