Skip to content

Commit

Permalink
Add Stable Diffusion 3 Example (#2558)
Browse files Browse the repository at this point in the history
* Add stable diffusion 3 example

Add get_qkv_linear to handle different dimensionality in linears

Add stable diffusion 3 example

Add use_quant_conv and use_post_quant_conv for vae in stable diffusion

adapt existing AutoEncoderKLConfig to the change

add forward_until_encoder_layer to ClipTextTransformer

rename sd3 config to sd3_medium in mmdit; minor clean-up

Enable flash-attn for mmdit impl when the feature is enabled.

Add sd3 example codebase

add document

crediting references

pass the cargo fmt test

pass the clippy test

* fix typos

* expose cfg_scale and time_shift as options

* Replace the sample image with JPG version. Change image output format accordingly.

* make meaningful error messages

* remove the tail-end assignment in sd3_vae_vb_rename

* remove the CUDA requirement

* use default_value in clap args

* add use_flash_attn to turn on/off flash-attn for MMDiT at runtime

* resolve clippy errors and warnings

* use default_value_t

* Pin the web-sys dependency.

* Clippy fix.

---------

Co-authored-by: Laurent <[email protected]>
  • Loading branch information
Czxck001 and LaurentMazare authored Oct 13, 2024
1 parent 0d96ec3 commit ca7cf5c
Show file tree
Hide file tree
Showing 16 changed files with 751 additions and 34 deletions.
3 changes: 3 additions & 0 deletions candle-examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,6 @@ required-features = ["onnx"]
[[example]]
name = "colpali"
required-features = ["pdf2image"]

[[example]]
name = "stable-diffusion-3"
54 changes: 54 additions & 0 deletions candle-examples/examples/stable-diffusion-3/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# candle-stable-diffusion-3: Candle Implementation of Stable Diffusion 3 Medium

![](assets/stable-diffusion-3.jpg)

*A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k*

Stable Diffusion 3 Medium is a text-to-image model based on Multimodal Diffusion Transformer (MMDiT) architecture.

- [huggingface repo](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
- [research paper](https://arxiv.org/pdf/2403.03206)
- [announcement blog post](https://stability.ai/news/stable-diffusion-3-medium)

## Getting access to the weights

The weights of Stable Diffusion 3 Medium is released by Stability AI under the Stability Community License. You will need to accept the conditions and acquire a license by visiting the [repo on HuggingFace Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium) to gain access to the weights for your HuggingFace account.

On the first run, the weights will be automatically downloaded from the Huggingface Hub. You might be prompted to configure a [Huggingface User Access Tokens](https://huggingface.co/docs/hub/en/security-tokens) (recommended) on your computer if you haven't done that before. After the download, the weights will be [cached](https://huggingface.co/docs/datasets/en/cache) and remain accessible locally.

## Running the model

```shell
cargo run --example stable-diffusion-3 --release --features=cuda -- \
--height 1024 --width 1024 \
--prompt 'A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k'
```

To display other options available,

```shell
cargo run --example stable-diffusion-3 --release --features=cuda -- --help
```

If GPU supports, Flash-Attention is a strongly recommended feature as it can greatly improve the speed of inference, as MMDiT is a transformer model heavily depends on attentions. To utilize [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) in the demo, you will need both `--features flash-attn` and `--use-flash-attn`.

```shell
cargo run --example stable-diffusion-3 --release --features=cuda,flash-attn -- --use-flash-attn ...
```

## Performance Benchmark

Below benchmark is done by generating 1024-by-1024 image from 28 steps of Euler sampling and measure the average speed (iteration per seconds).

[candle](https://github.com/huggingface/candle) and [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) is based on the commit of [0d96ec3](https://github.com/huggingface/candle/commit/0d96ec31e8be03f844ed0aed636d6217dee9c7bc).

System specs (Desktop PCIE 5 x8/x8 dual-GPU setup):

- Operating System: Ubuntu 23.10
- CPU: i9 12900K w/o overclocking.
- RAM: 64G dual-channel DDR5 @ 4800 MT/s

| Speed (iter/s) | w/o flash-attn | w/ flash-attn |
| -------------- | -------------- | ------------- |
| RTX 3090 Ti | 0.83 | 2.15 |
| RTX 4090 | 1.72 | 4.06 |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
201 changes: 201 additions & 0 deletions candle-examples/examples/stable-diffusion-3/clip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
use anyhow::{Error as E, Ok, Result};
use candle::{DType, IndexOp, Module, Tensor, D};
use candle_transformers::models::{stable_diffusion, t5};
use tokenizers::tokenizer::Tokenizer;

struct ClipWithTokenizer {
clip: stable_diffusion::clip::ClipTextTransformer,
config: stable_diffusion::clip::Config,
tokenizer: Tokenizer,
max_position_embeddings: usize,
}

impl ClipWithTokenizer {
fn new(
vb: candle_nn::VarBuilder,
config: stable_diffusion::clip::Config,
tokenizer_path: &str,
max_position_embeddings: usize,
) -> Result<Self> {
let clip = stable_diffusion::clip::ClipTextTransformer::new(vb, &config)?;
let path_buf = hf_hub::api::sync::Api::new()?
.model(tokenizer_path.to_string())
.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(path_buf.to_str().ok_or(E::msg(
"Failed to serialize huggingface PathBuf of CLIP tokenizer",
))?)
.map_err(E::msg)?;
Ok(Self {
clip,
config,
tokenizer,
max_position_embeddings,
})
}

fn encode_text_to_embedding(
&self,
prompt: &str,
device: &candle::Device,
) -> Result<(Tensor, Tensor)> {
let pad_id = match &self.config.pad_with {
Some(padding) => *self
.tokenizer
.get_vocab(true)
.get(padding.as_str())
.ok_or(E::msg("Failed to tokenize CLIP padding."))?,
None => *self
.tokenizer
.get_vocab(true)
.get("<|endoftext|>")
.ok_or(E::msg("Failed to tokenize CLIP end-of-text."))?,
};

let mut tokens = self
.tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();

let eos_position = tokens.len() - 1;

while tokens.len() < self.max_position_embeddings {
tokens.push(pad_id)
}
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
let (text_embeddings, text_embeddings_penultimate) = self
.clip
.forward_until_encoder_layer(&tokens, usize::MAX, -2)?;
let text_embeddings_pooled = text_embeddings.i((0, eos_position, ..))?;

Ok((text_embeddings_penultimate, text_embeddings_pooled))
}
}

struct T5WithTokenizer {
t5: t5::T5EncoderModel,
tokenizer: Tokenizer,
max_position_embeddings: usize,
}

impl T5WithTokenizer {
fn new(vb: candle_nn::VarBuilder, max_position_embeddings: usize) -> Result<Self> {
let api = hf_hub::api::sync::Api::new()?;
let repo = api.repo(hf_hub::Repo::with_revision(
"google/t5-v1_1-xxl".to_string(),
hf_hub::RepoType::Model,
"refs/pr/2".to_string(),
));
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
let config: t5::Config = serde_json::from_str(&config)?;
let model = t5::T5EncoderModel::load(vb, &config)?;

let tokenizer_filename = api
.model("lmz/mt5-tokenizers".to_string())
.get("t5-v1_1-xxl.tokenizer.json")?;

let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
Ok(Self {
t5: model,
tokenizer,
max_position_embeddings,
})
}

fn encode_text_to_embedding(
&mut self,
prompt: &str,
device: &candle::Device,
) -> Result<Tensor> {
let mut tokens = self
.tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
tokens.resize(self.max_position_embeddings, 0);
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let embeddings = self.t5.forward(&input_token_ids)?;
Ok(embeddings)
}
}

pub struct StableDiffusion3TripleClipWithTokenizer {
clip_l: ClipWithTokenizer,
clip_g: ClipWithTokenizer,
clip_g_text_projection: candle_nn::Linear,
t5: T5WithTokenizer,
}

impl StableDiffusion3TripleClipWithTokenizer {
pub fn new(vb_fp16: candle_nn::VarBuilder, vb_fp32: candle_nn::VarBuilder) -> Result<Self> {
let max_position_embeddings = 77usize;
let clip_l = ClipWithTokenizer::new(
vb_fp16.pp("clip_l.transformer"),
stable_diffusion::clip::Config::sdxl(),
"openai/clip-vit-large-patch14",
max_position_embeddings,
)?;

let clip_g = ClipWithTokenizer::new(
vb_fp16.pp("clip_g.transformer"),
stable_diffusion::clip::Config::sdxl2(),
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
max_position_embeddings,
)?;

let text_projection = candle_nn::linear_no_bias(
1280,
1280,
vb_fp16.pp("clip_g.transformer.text_projection"),
)?;

// Current T5 implementation does not support fp16, so we use fp32 VarBuilder for T5.
// This is a temporary workaround until the T5 implementation is updated to support fp16.
// Also see:
// https://github.com/huggingface/candle/issues/2480
// https://github.com/huggingface/candle/pull/2481
let t5 = T5WithTokenizer::new(vb_fp32.pp("t5xxl.transformer"), max_position_embeddings)?;

Ok(Self {
clip_l,
clip_g,
clip_g_text_projection: text_projection,
t5,
})
}

pub fn encode_text_to_embedding(
&mut self,
prompt: &str,
device: &candle::Device,
) -> Result<(Tensor, Tensor)> {
let (clip_l_embeddings, clip_l_embeddings_pooled) =
self.clip_l.encode_text_to_embedding(prompt, device)?;
let (clip_g_embeddings, clip_g_embeddings_pooled) =
self.clip_g.encode_text_to_embedding(prompt, device)?;

let clip_g_embeddings_pooled = self
.clip_g_text_projection
.forward(&clip_g_embeddings_pooled.unsqueeze(0)?)?
.squeeze(0)?;

let y = Tensor::cat(&[&clip_l_embeddings_pooled, &clip_g_embeddings_pooled], 0)?
.unsqueeze(0)?;
let clip_embeddings_concat = Tensor::cat(
&[&clip_l_embeddings, &clip_g_embeddings],
D::Minus1,
)?
.pad_with_zeros(D::Minus1, 0, 2048)?;

let t5_embeddings = self
.t5
.encode_text_to_embedding(prompt, device)?
.to_dtype(DType::F16)?;
let context = Tensor::cat(&[&clip_embeddings_concat, &t5_embeddings], D::Minus2)?;

Ok((context, y))
}
}
Loading

0 comments on commit ca7cf5c

Please sign in to comment.