Skip to content

Commit

Permalink
Merge remote-tracking branch 'candle/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
akshayballal95 committed Sep 29, 2024
2 parents 8d92718 + 2f49e1b commit 15d3c32
Show file tree
Hide file tree
Showing 31 changed files with 2,677 additions and 401 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/ci_cuda.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ jobs:
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
runs-on: [single-gpu, nvidia-gpu, t4, ci]
runs-on:
group: aws-g4dn-2xlarge
container:
image: nvidia/cuda:12.3.1-devel-ubuntu22.04
options: --gpus 0
Expand Down
18 changes: 9 additions & 9 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ exclude = [
resolver = "2"

[workspace.package]
version = "0.7.1"
version = "0.7.2"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
Expand All @@ -33,14 +33,14 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.7.1" }
candle-datasets = { path = "./candle-datasets", version = "0.7.1" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.7.1" }
candle-kernels = { path = "./candle-kernels", version = "0.7.1" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.7.1" }
candle-nn = { path = "./candle-nn", version = "0.7.1" }
candle-onnx = { path = "./candle-onnx", version = "0.7.1" }
candle-transformers = { path = "./candle-transformers", version = "0.7.1" }
candle = { path = "./candle-core", package = "candle-core", version = "0.7.2" }
candle-datasets = { path = "./candle-datasets", version = "0.7.2" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.7.2" }
candle-kernels = { path = "./candle-kernels", version = "0.7.2" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.7.2" }
candle-nn = { path = "./candle-nn", version = "0.7.2" }
candle-onnx = { path = "./candle-onnx", version = "0.7.2" }
candle-transformers = { path = "./candle-transformers", version = "0.7.2" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
Expand Down
32 changes: 3 additions & 29 deletions candle-examples/examples/clip/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use candle_transformers::models::clip::{self, ClipConfig};

use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
use tracing::info;

#[derive(Parser)]
struct Args {
Expand Down Expand Up @@ -57,15 +56,12 @@ fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::
height as u32,
image::imageops::FilterType::Triangle,
);

let img = img.to_rgb8();

let img = img.into_raw();
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
.permute((2, 0, 1))?
.to_dtype(DType::F32)?
.affine(2. / 255., -1.)?;
// .unsqueeze(0)?;
Ok(img)
}

Expand All @@ -74,20 +70,15 @@ fn load_images<T: AsRef<std::path::Path>>(
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];

for path in paths {
let tensor = load_image(path, image_size)?;
images.push(tensor);
}

let images = Tensor::stack(&images, 0)?;

Ok(images)
}

pub fn main() -> anyhow::Result<()> {
// std::env::set_var("RUST_BACKTRACE", "full");

let args = Args::parse();

tracing_subscriber::fmt::init();
Expand Down Expand Up @@ -146,35 +137,25 @@ pub fn main() -> anyhow::Result<()> {
};

let model = clip::ClipModel::new(vb, &config)?;

let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;

let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;

let softmax_image = softmax(&logits_per_image, 1)?;

let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;

info!("softmax_image_vec: {:?}", softmax_image_vec);

println!("softmax_image_vec: {:?}", softmax_image_vec);
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();

let probability_per_image = probability_vec.len() / vec_imgs.len();

for (i, img) in vec_imgs.iter().enumerate() {
let start = i * probability_per_image;
let end = start + probability_per_image;
let prob = &probability_vec[start..end];
info!("\n\nResults for image: {}\n", img);

println!("\n\nResults for image: {}\n", img);
for (i, p) in prob.iter().enumerate() {
info!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
println!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
}
}

Ok(())
}

Expand All @@ -187,7 +168,6 @@ pub fn tokenize_sequences(
.get_vocab(true)
.get("<|endoftext|>")
.ok_or(E::msg("No pad token"))?;

let vec_seq = match sequences {
Some(seq) => seq,
None => vec![
Expand All @@ -196,25 +176,19 @@ pub fn tokenize_sequences(
"a robot holding a candle".to_string(),
],
};

let mut tokens = vec![];

for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
tokens.push(encoding.get_ids().to_vec());
}

let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);

// Pad the sequences to have the same length
for token_vec in tokens.iter_mut() {
let len_diff = max_len - token_vec.len();
if len_diff > 0 {
token_vec.extend(vec![pad_id; len_diff]);
}
}

let input_ids = Tensor::new(tokens, device)?;

Ok((input_ids, vec_seq))
}
2 changes: 1 addition & 1 deletion candle-examples/examples/flux/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ descriptions,

```bash
cargo run --features cuda --example flux -r -- \
--height 1024 --width 1024
--height 1024 --width 1024 \
--prompt "a rusty robot walking on a beach holding a small torch, the robot has the word "rust" written on it, high quality, 4k"
```

83 changes: 64 additions & 19 deletions candle-examples/examples/flux/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ struct Args {
#[arg(long)]
cpu: bool,

/// Use the quantized model.
#[arg(long)]
quantized: bool,

/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
Expand All @@ -40,6 +44,10 @@ struct Args {

#[arg(long, value_enum, default_value = "schnell")]
model: Model,

/// Use the faster kernels which are buggy at the moment.
#[arg(long)]
no_dmmv: bool,
}

#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
Expand All @@ -60,6 +68,8 @@ fn run(args: Args) -> Result<()> {
tracing,
decode_only,
model,
quantized,
..
} = args;
let width = width.unwrap_or(1360);
let height = height.unwrap_or(768);
Expand Down Expand Up @@ -146,38 +156,71 @@ fn run(args: Args) -> Result<()> {
};
println!("CLIP\n{clip_emb}");
let img = {
let model_file = match model {
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
};
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let cfg = match model {
Model::Dev => flux::model::Config::dev(),
Model::Schnell => flux::model::Config::schnell(),
};
let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?;
let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?;
let state = if quantized {
flux::sampling::State::new(
&t5_emb.to_dtype(candle::DType::F32)?,
&clip_emb.to_dtype(candle::DType::F32)?,
&img.to_dtype(candle::DType::F32)?,
)?
} else {
flux::sampling::State::new(&t5_emb, &clip_emb, &img)?
};
let timesteps = match model {
Model::Dev => {
flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15)))
}
Model::Schnell => flux::sampling::get_schedule(4, None),
};
let model = flux::model::Flux::new(&cfg, vb)?;

println!("{state:?}");
println!("{timesteps:?}");
flux::sampling::denoise(
&model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
&timesteps,
4.,
)?
if quantized {
let model_file = match model {
Model::Schnell => api
.repo(hf_hub::Repo::model("lmz/candle-flux".to_string()))
.get("flux1-schnell.gguf")?,
Model::Dev => todo!(),
};
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
model_file, &device,
)?;

let model = flux::quantized_model::Flux::new(&cfg, vb)?;
flux::sampling::denoise(
&model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
&timesteps,
4.,
)?
.to_dtype(dtype)?
} else {
let model_file = match model {
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
};
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)?
};
let model = flux::model::Flux::new(&cfg, vb)?;
flux::sampling::denoise(
&model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
&timesteps,
4.,
)?
}
};
flux::sampling::unpack(&img, height, width)?
}
Expand Down Expand Up @@ -206,5 +249,7 @@ fn run(args: Args) -> Result<()> {

fn main() -> Result<()> {
let args = Args::parse();
#[cfg(feature = "cuda")]
candle::quantized::cuda::set_force_dmmv(!args.no_dmmv);
run(args)
}
14 changes: 13 additions & 1 deletion candle-examples/examples/llama/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ enum Which {
V31,
V3Instruct,
V31Instruct,
V32_1b,
V32_1bInstruct,
V32_3b,
V32_3bInstruct,
#[value(name = "solar-10.7b")]
Solar10_7B,
#[value(name = "tiny-llama-1.1b-chat")]
Expand Down Expand Up @@ -137,6 +141,10 @@ fn main() -> Result<()> {
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
Which::V31 => "meta-llama/Meta-Llama-3.1-8B".to_string(),
Which::V31Instruct => "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string(),
Which::V32_1b => "meta-llama/Llama-3.2-1B".to_string(),
Which::V32_1bInstruct => "meta-llama/Llama-3.2-1B-Instruct".to_string(),
Which::V32_3b => "meta-llama/Llama-3.2-3B".to_string(),
Which::V32_3bInstruct => "meta-llama/Llama-3.2-3B-Instruct".to_string(),
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
});
Expand All @@ -156,10 +164,14 @@ fn main() -> Result<()> {
| Which::V3Instruct
| Which::V31
| Which::V31Instruct
| Which::V32_3b
| Which::V32_3bInstruct
| Which::Solar10_7B => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
}
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
Which::V32_1b | Which::V32_1bInstruct | Which::TinyLlama1_1BChat => {
vec![api.get("model.safetensors")?]
}
};
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;

Expand Down
Loading

0 comments on commit 15d3c32

Please sign in to comment.