From fa1e0e438e8f8424948785f6abe8d98c8ababc4e Mon Sep 17 00:00:00 2001 From: Laurent Date: Wed, 25 Sep 2024 11:07:49 +0200 Subject: [PATCH] Quantized version of flux. --- candle-transformers/src/models/flux/mod.rs | 1 + .../src/models/flux/quantized_model.rs | 626 ++++++++++++++++++ 2 files changed, 627 insertions(+) create mode 100644 candle-transformers/src/models/flux/quantized_model.rs diff --git a/candle-transformers/src/models/flux/mod.rs b/candle-transformers/src/models/flux/mod.rs index 763fa90da1..340d23c209 100644 --- a/candle-transformers/src/models/flux/mod.rs +++ b/candle-transformers/src/models/flux/mod.rs @@ -1,3 +1,4 @@ pub mod autoencoder; pub mod model; +pub mod quantized_model; pub mod sampling; diff --git a/candle-transformers/src/models/flux/quantized_model.rs b/candle-transformers/src/models/flux/quantized_model.rs new file mode 100644 index 0000000000..aba9ad1344 --- /dev/null +++ b/candle-transformers/src/models/flux/quantized_model.rs @@ -0,0 +1,626 @@ +use crate::quantized_nn::{linear, linear_b, Linear}; +use crate::quantized_var_builder::VarBuilder; +use candle::{DType, IndexOp, Result, Tensor, D}; +use candle_nn::{LayerNorm, RmsNorm}; + +// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/model.py#L12 +#[derive(Debug, Clone)] +pub struct Config { + pub in_channels: usize, + pub vec_in_dim: usize, + pub context_in_dim: usize, + pub hidden_size: usize, + pub mlp_ratio: f64, + pub num_heads: usize, + pub depth: usize, + pub depth_single_blocks: usize, + pub axes_dim: Vec, + pub theta: usize, + pub qkv_bias: bool, + pub guidance_embed: bool, +} + +impl Config { + // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L32 + pub fn dev() -> Self { + Self { + in_channels: 64, + vec_in_dim: 768, + context_in_dim: 4096, + hidden_size: 3072, + mlp_ratio: 4.0, + num_heads: 24, + depth: 19, + depth_single_blocks: 38, + axes_dim: vec![16, 56, 56], + theta: 10_000, + qkv_bias: true, + guidance_embed: true, + } + } + + // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L64 + pub fn schnell() -> Self { + Self { + in_channels: 64, + vec_in_dim: 768, + context_in_dim: 4096, + hidden_size: 3072, + mlp_ratio: 4.0, + num_heads: 24, + depth: 19, + depth_single_blocks: 38, + axes_dim: vec![16, 56, 56], + theta: 10_000, + qkv_bias: true, + guidance_embed: false, + } + } +} + +fn layer_norm(dim: usize, vb: VarBuilder) -> Result { + let ws = Tensor::ones(dim, DType::F32, vb.device())?; + Ok(LayerNorm::new_no_bias(ws, 1e-6)) +} + +fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result { + let dim = q.dim(D::Minus1)?; + let scale_factor = 1.0 / (dim as f64).sqrt(); + let mut batch_dims = q.dims().to_vec(); + batch_dims.pop(); + batch_dims.pop(); + let q = q.flatten_to(batch_dims.len() - 1)?; + let k = k.flatten_to(batch_dims.len() - 1)?; + let v = v.flatten_to(batch_dims.len() - 1)?; + let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?; + let attn_scores = candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(&v)?; + batch_dims.push(attn_scores.dim(D::Minus2)?); + batch_dims.push(attn_scores.dim(D::Minus1)?); + attn_scores.reshape(batch_dims) +} + +fn rope(pos: &Tensor, dim: usize, theta: usize) -> Result { + if dim % 2 == 1 { + candle::bail!("dim {dim} is odd") + } + let dev = pos.device(); + let theta = theta as f64; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / theta.powf(i as f64 / dim as f64) as f32) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, 1, inv_freq_len), dev)?; + let inv_freq = inv_freq.to_dtype(pos.dtype())?; + let freqs = pos.unsqueeze(2)?.broadcast_mul(&inv_freq)?; + let cos = freqs.cos()?; + let sin = freqs.sin()?; + let out = Tensor::stack(&[&cos, &sin.neg()?, &sin, &cos], 3)?; + let (b, n, d, _ij) = out.dims4()?; + out.reshape((b, n, d, 2, 2)) +} + +fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result { + let dims = x.dims(); + let (b_sz, n_head, seq_len, n_embd) = x.dims4()?; + let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?; + let x0 = x.narrow(D::Minus1, 0, 1)?; + let x1 = x.narrow(D::Minus1, 1, 1)?; + let fr0 = freq_cis.get_on_dim(D::Minus1, 0)?; + let fr1 = freq_cis.get_on_dim(D::Minus1, 1)?; + (fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec()) +} + +fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result { + let q = apply_rope(q, pe)?.contiguous()?; + let k = apply_rope(k, pe)?.contiguous()?; + let x = scaled_dot_product_attention(&q, &k, v)?; + x.transpose(1, 2)?.flatten_from(2) +} + +fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result { + const TIME_FACTOR: f64 = 1000.; + const MAX_PERIOD: f64 = 10000.; + if dim % 2 == 1 { + candle::bail!("{dim} is odd") + } + let dev = t.device(); + let half = dim / 2; + let t = (t * TIME_FACTOR)?; + let arange = Tensor::arange(0, half as u32, dev)?.to_dtype(candle::DType::F32)?; + let freqs = (arange * (-MAX_PERIOD.ln() / half as f64))?.exp()?; + let args = t + .unsqueeze(1)? + .to_dtype(candle::DType::F32)? + .broadcast_mul(&freqs.unsqueeze(0)?)?; + let emb = Tensor::cat(&[args.cos()?, args.sin()?], D::Minus1)?.to_dtype(dtype)?; + Ok(emb) +} + +#[derive(Debug, Clone)] +pub struct EmbedNd { + #[allow(unused)] + dim: usize, + theta: usize, + axes_dim: Vec, +} + +impl EmbedNd { + fn new(dim: usize, theta: usize, axes_dim: Vec) -> Self { + Self { + dim, + theta, + axes_dim, + } + } +} + +impl candle::Module for EmbedNd { + fn forward(&self, ids: &Tensor) -> Result { + let n_axes = ids.dim(D::Minus1)?; + let mut emb = Vec::with_capacity(n_axes); + for idx in 0..n_axes { + let r = rope( + &ids.get_on_dim(D::Minus1, idx)?, + self.axes_dim[idx], + self.theta, + )?; + emb.push(r) + } + let emb = Tensor::cat(&emb, 2)?; + emb.unsqueeze(1) + } +} + +#[derive(Debug, Clone)] +pub struct MlpEmbedder { + in_layer: Linear, + out_layer: Linear, +} + +impl MlpEmbedder { + fn new(in_sz: usize, h_sz: usize, vb: VarBuilder) -> Result { + let in_layer = linear(in_sz, h_sz, vb.pp("in_layer"))?; + let out_layer = linear(h_sz, h_sz, vb.pp("out_layer"))?; + Ok(Self { + in_layer, + out_layer, + }) + } +} + +impl candle::Module for MlpEmbedder { + fn forward(&self, xs: &Tensor) -> Result { + xs.apply(&self.in_layer)?.silu()?.apply(&self.out_layer) + } +} + +#[derive(Debug, Clone)] +pub struct QkNorm { + query_norm: RmsNorm, + key_norm: RmsNorm, +} + +impl QkNorm { + fn new(dim: usize, vb: VarBuilder) -> Result { + let query_norm = vb.get(dim, "query_norm.scale")?.dequantize(vb.device())?; + let query_norm = RmsNorm::new(query_norm, 1e-6); + let key_norm = vb.get(dim, "key_norm.scale")?.dequantize(vb.device())?; + let key_norm = RmsNorm::new(key_norm, 1e-6); + Ok(Self { + query_norm, + key_norm, + }) + } +} + +struct ModulationOut { + shift: Tensor, + scale: Tensor, + gate: Tensor, +} + +impl ModulationOut { + fn scale_shift(&self, xs: &Tensor) -> Result { + xs.broadcast_mul(&(&self.scale + 1.)?)? + .broadcast_add(&self.shift) + } + + fn gate(&self, xs: &Tensor) -> Result { + self.gate.broadcast_mul(xs) + } +} + +#[derive(Debug, Clone)] +struct Modulation1 { + lin: Linear, +} + +impl Modulation1 { + fn new(dim: usize, vb: VarBuilder) -> Result { + let lin = linear(dim, 3 * dim, vb.pp("lin"))?; + Ok(Self { lin }) + } + + fn forward(&self, vec_: &Tensor) -> Result { + let ys = vec_ + .silu()? + .apply(&self.lin)? + .unsqueeze(1)? + .chunk(3, D::Minus1)?; + if ys.len() != 3 { + candle::bail!("unexpected len from chunk {ys:?}") + } + Ok(ModulationOut { + shift: ys[0].clone(), + scale: ys[1].clone(), + gate: ys[2].clone(), + }) + } +} + +#[derive(Debug, Clone)] +struct Modulation2 { + lin: Linear, +} + +impl Modulation2 { + fn new(dim: usize, vb: VarBuilder) -> Result { + let lin = linear(dim, 6 * dim, vb.pp("lin"))?; + Ok(Self { lin }) + } + + fn forward(&self, vec_: &Tensor) -> Result<(ModulationOut, ModulationOut)> { + let ys = vec_ + .silu()? + .apply(&self.lin)? + .unsqueeze(1)? + .chunk(6, D::Minus1)?; + if ys.len() != 6 { + candle::bail!("unexpected len from chunk {ys:?}") + } + let mod1 = ModulationOut { + shift: ys[0].clone(), + scale: ys[1].clone(), + gate: ys[2].clone(), + }; + let mod2 = ModulationOut { + shift: ys[3].clone(), + scale: ys[4].clone(), + gate: ys[5].clone(), + }; + Ok((mod1, mod2)) + } +} + +#[derive(Debug, Clone)] +pub struct SelfAttention { + qkv: Linear, + norm: QkNorm, + proj: Linear, + num_heads: usize, +} + +impl SelfAttention { + fn new(dim: usize, num_heads: usize, qkv_bias: bool, vb: VarBuilder) -> Result { + let head_dim = dim / num_heads; + let qkv = linear_b(dim, dim * 3, qkv_bias, vb.pp("qkv"))?; + let norm = QkNorm::new(head_dim, vb.pp("norm"))?; + let proj = linear(dim, dim, vb.pp("proj"))?; + Ok(Self { + qkv, + norm, + proj, + num_heads, + }) + } + + fn qkv(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> { + let qkv = xs.apply(&self.qkv)?; + let (b, l, _khd) = qkv.dims3()?; + let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?; + let q = qkv.i((.., .., 0))?.transpose(1, 2)?; + let k = qkv.i((.., .., 1))?.transpose(1, 2)?; + let v = qkv.i((.., .., 2))?.transpose(1, 2)?; + let q = q.apply(&self.norm.query_norm)?; + let k = k.apply(&self.norm.key_norm)?; + Ok((q, k, v)) + } + + #[allow(unused)] + fn forward(&self, xs: &Tensor, pe: &Tensor) -> Result { + let (q, k, v) = self.qkv(xs)?; + attention(&q, &k, &v, pe)?.apply(&self.proj) + } +} + +#[derive(Debug, Clone)] +struct Mlp { + lin1: Linear, + lin2: Linear, +} + +impl Mlp { + fn new(in_sz: usize, mlp_sz: usize, vb: VarBuilder) -> Result { + let lin1 = linear(in_sz, mlp_sz, vb.pp("0"))?; + let lin2 = linear(mlp_sz, in_sz, vb.pp("2"))?; + Ok(Self { lin1, lin2 }) + } +} + +impl candle::Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result { + xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2) + } +} + +#[derive(Debug, Clone)] +pub struct DoubleStreamBlock { + img_mod: Modulation2, + img_norm1: LayerNorm, + img_attn: SelfAttention, + img_norm2: LayerNorm, + img_mlp: Mlp, + txt_mod: Modulation2, + txt_norm1: LayerNorm, + txt_attn: SelfAttention, + txt_norm2: LayerNorm, + txt_mlp: Mlp, +} + +impl DoubleStreamBlock { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let h_sz = cfg.hidden_size; + let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize; + let img_mod = Modulation2::new(h_sz, vb.pp("img_mod"))?; + let img_norm1 = layer_norm(h_sz, vb.pp("img_norm1"))?; + let img_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("img_attn"))?; + let img_norm2 = layer_norm(h_sz, vb.pp("img_norm2"))?; + let img_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("img_mlp"))?; + let txt_mod = Modulation2::new(h_sz, vb.pp("txt_mod"))?; + let txt_norm1 = layer_norm(h_sz, vb.pp("txt_norm1"))?; + let txt_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("txt_attn"))?; + let txt_norm2 = layer_norm(h_sz, vb.pp("txt_norm2"))?; + let txt_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("txt_mlp"))?; + Ok(Self { + img_mod, + img_norm1, + img_attn, + img_norm2, + img_mlp, + txt_mod, + txt_norm1, + txt_attn, + txt_norm2, + txt_mlp, + }) + } + + fn forward( + &self, + img: &Tensor, + txt: &Tensor, + vec_: &Tensor, + pe: &Tensor, + ) -> Result<(Tensor, Tensor)> { + let (img_mod1, img_mod2) = self.img_mod.forward(vec_)?; // shift, scale, gate + let (txt_mod1, txt_mod2) = self.txt_mod.forward(vec_)?; // shift, scale, gate + let img_modulated = img.apply(&self.img_norm1)?; + let img_modulated = img_mod1.scale_shift(&img_modulated)?; + let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?; + + let txt_modulated = txt.apply(&self.txt_norm1)?; + let txt_modulated = txt_mod1.scale_shift(&txt_modulated)?; + let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?; + + let q = Tensor::cat(&[txt_q, img_q], 2)?; + let k = Tensor::cat(&[txt_k, img_k], 2)?; + let v = Tensor::cat(&[txt_v, img_v], 2)?; + + let attn = attention(&q, &k, &v, pe)?; + let txt_attn = attn.narrow(1, 0, txt.dim(1)?)?; + let img_attn = attn.narrow(1, txt.dim(1)?, attn.dim(1)? - txt.dim(1)?)?; + + let img = (img + img_mod1.gate(&img_attn.apply(&self.img_attn.proj)?))?; + let img = (&img + + img_mod2.gate( + &img_mod2 + .scale_shift(&img.apply(&self.img_norm2)?)? + .apply(&self.img_mlp)?, + )?)?; + + let txt = (txt + txt_mod1.gate(&txt_attn.apply(&self.txt_attn.proj)?))?; + let txt = (&txt + + txt_mod2.gate( + &txt_mod2 + .scale_shift(&txt.apply(&self.txt_norm2)?)? + .apply(&self.txt_mlp)?, + )?)?; + + Ok((img, txt)) + } +} + +#[derive(Debug, Clone)] +pub struct SingleStreamBlock { + linear1: Linear, + linear2: Linear, + norm: QkNorm, + pre_norm: LayerNorm, + modulation: Modulation1, + h_sz: usize, + mlp_sz: usize, + num_heads: usize, +} + +impl SingleStreamBlock { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let h_sz = cfg.hidden_size; + let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize; + let head_dim = h_sz / cfg.num_heads; + let linear1 = linear(h_sz, h_sz * 3 + mlp_sz, vb.pp("linear1"))?; + let linear2 = linear(h_sz + mlp_sz, h_sz, vb.pp("linear2"))?; + let norm = QkNorm::new(head_dim, vb.pp("norm"))?; + let pre_norm = layer_norm(h_sz, vb.pp("pre_norm"))?; + let modulation = Modulation1::new(h_sz, vb.pp("modulation"))?; + Ok(Self { + linear1, + linear2, + norm, + pre_norm, + modulation, + h_sz, + mlp_sz, + num_heads: cfg.num_heads, + }) + } + + fn forward(&self, xs: &Tensor, vec_: &Tensor, pe: &Tensor) -> Result { + let mod_ = self.modulation.forward(vec_)?; + let x_mod = mod_.scale_shift(&xs.apply(&self.pre_norm)?)?; + let x_mod = x_mod.apply(&self.linear1)?; + let qkv = x_mod.narrow(D::Minus1, 0, 3 * self.h_sz)?; + let (b, l, _khd) = qkv.dims3()?; + let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?; + let q = qkv.i((.., .., 0))?.transpose(1, 2)?; + let k = qkv.i((.., .., 1))?.transpose(1, 2)?; + let v = qkv.i((.., .., 2))?.transpose(1, 2)?; + let mlp = x_mod.narrow(D::Minus1, 3 * self.h_sz, self.mlp_sz)?; + let q = q.apply(&self.norm.query_norm)?; + let k = k.apply(&self.norm.key_norm)?; + let attn = attention(&q, &k, &v, pe)?; + let output = Tensor::cat(&[attn, mlp.gelu()?], 2)?.apply(&self.linear2)?; + xs + mod_.gate(&output) + } +} + +#[derive(Debug, Clone)] +pub struct LastLayer { + norm_final: LayerNorm, + linear: Linear, + ada_ln_modulation: Linear, +} + +impl LastLayer { + fn new(h_sz: usize, p_sz: usize, out_c: usize, vb: VarBuilder) -> Result { + let norm_final = layer_norm(h_sz, vb.pp("norm_final"))?; + let linear_ = linear(h_sz, p_sz * p_sz * out_c, vb.pp("linear"))?; + let ada_ln_modulation = linear(h_sz, 2 * h_sz, vb.pp("adaLN_modulation.1"))?; + Ok(Self { + norm_final, + linear: linear_, + ada_ln_modulation, + }) + } + + fn forward(&self, xs: &Tensor, vec: &Tensor) -> Result { + let chunks = vec.silu()?.apply(&self.ada_ln_modulation)?.chunk(2, 1)?; + let (shift, scale) = (&chunks[0], &chunks[1]); + let xs = xs + .apply(&self.norm_final)? + .broadcast_mul(&(scale.unsqueeze(1)? + 1.0)?)? + .broadcast_add(&shift.unsqueeze(1)?)?; + xs.apply(&self.linear) + } +} + +#[derive(Debug, Clone)] +pub struct Flux { + img_in: Linear, + txt_in: Linear, + time_in: MlpEmbedder, + vector_in: MlpEmbedder, + guidance_in: Option, + pe_embedder: EmbedNd, + double_blocks: Vec, + single_blocks: Vec, + final_layer: LastLayer, +} + +impl Flux { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let img_in = linear(cfg.in_channels, cfg.hidden_size, vb.pp("img_in"))?; + let txt_in = linear(cfg.context_in_dim, cfg.hidden_size, vb.pp("txt_in"))?; + let mut double_blocks = Vec::with_capacity(cfg.depth); + let vb_d = vb.pp("double_blocks"); + for idx in 0..cfg.depth { + let db = DoubleStreamBlock::new(cfg, vb_d.pp(idx))?; + double_blocks.push(db) + } + let mut single_blocks = Vec::with_capacity(cfg.depth_single_blocks); + let vb_s = vb.pp("single_blocks"); + for idx in 0..cfg.depth_single_blocks { + let sb = SingleStreamBlock::new(cfg, vb_s.pp(idx))?; + single_blocks.push(sb) + } + let time_in = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("time_in"))?; + let vector_in = MlpEmbedder::new(cfg.vec_in_dim, cfg.hidden_size, vb.pp("vector_in"))?; + let guidance_in = if cfg.guidance_embed { + let mlp = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("guidance_in"))?; + Some(mlp) + } else { + None + }; + let final_layer = + LastLayer::new(cfg.hidden_size, 1, cfg.in_channels, vb.pp("final_layer"))?; + let pe_dim = cfg.hidden_size / cfg.num_heads; + let pe_embedder = EmbedNd::new(pe_dim, cfg.theta, cfg.axes_dim.to_vec()); + Ok(Self { + img_in, + txt_in, + time_in, + vector_in, + guidance_in, + pe_embedder, + double_blocks, + single_blocks, + final_layer, + }) + } + + #[allow(clippy::too_many_arguments)] + pub fn forward( + &self, + img: &Tensor, + img_ids: &Tensor, + txt: &Tensor, + txt_ids: &Tensor, + timesteps: &Tensor, + y: &Tensor, + guidance: Option<&Tensor>, + ) -> Result { + if txt.rank() != 3 { + candle::bail!("unexpected shape for txt {:?}", txt.shape()) + } + if img.rank() != 3 { + candle::bail!("unexpected shape for img {:?}", img.shape()) + } + let dtype = img.dtype(); + let pe = { + let ids = Tensor::cat(&[txt_ids, img_ids], 1)?; + ids.apply(&self.pe_embedder)? + }; + let mut txt = txt.apply(&self.txt_in)?; + let mut img = img.apply(&self.img_in)?; + let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?; + let vec_ = match (self.guidance_in.as_ref(), guidance) { + (Some(g_in), Some(guidance)) => { + (vec_ + timestep_embedding(guidance, 256, dtype)?.apply(g_in))? + } + _ => vec_, + }; + let vec_ = (vec_ + y.apply(&self.vector_in))?; + + // Double blocks + for block in self.double_blocks.iter() { + (img, txt) = block.forward(&img, &txt, &vec_, &pe)? + } + // Single blocks + let mut img = Tensor::cat(&[&txt, &img], 1)?; + for block in self.single_blocks.iter() { + img = block.forward(&img, &vec_, &pe)?; + } + let img = img.i((.., txt.dim(1)?..))?; + self.final_layer.forward(&img, &vec_) + } +}