From 1735e4831e3aa29c1a107e55ba44f0a606488837 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 21 Jul 2023 15:10:51 +0000 Subject: [PATCH 1/8] TP sharding v2 --- Cargo.toml | 6 +- candle-core/src/safetensors.rs | 43 +- candle-examples/Cargo.toml | 7 + .../examples/llama_multiprocess/main.rs | 266 ++++++++++ .../examples/llama_multiprocess/model.rs | 465 ++++++++++++++++++ candle-nn/Cargo.toml | 1 + candle-nn/src/var_builder.rs | 58 ++- candle-wasm-examples/whisper/Cargo.toml | 1 + candle-wasm-examples/whisper/src/worker.rs | 4 +- 9 files changed, 833 insertions(+), 18 deletions(-) create mode 100644 candle-examples/examples/llama_multiprocess/main.rs create mode 100644 candle-examples/examples/llama_multiprocess/model.rs diff --git a/Cargo.toml b/Cargo.toml index 0dec835b7..d9613cdce 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,8 +19,10 @@ byteorder = "1.4.3" clap = { version = "4.2.4", features = ["derive"] } # Re-enable this once 0.9.13 as been released as it would include the cublas-f16 changes # cudarc = { version = "0.9.13", optional = true, features = ["f16"] } -cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas-bf16", features = ["f16"] } -# TODO: Switch back to the official gemm implementation if we manage to upstream the changes. +cudarc = { git = "https://github.com/coreylowman/cudarc.git", features = ["f16", "nccl"] } +# TODO: Switch back to the official gemm implementation once the following are available. +# https://github.com/sarah-ek/gemm/pull/8. +# https://github.com/sarah-ek/gemm/pull/9. gemm = { git = "https://github.com/LaurentMazare/gemm.git" } hf-hub = "0.1.3" half = { version = "2.3.1", features = ["num-traits", "rand_distr"] } diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index 3bb069a97..e81fe184e 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -1,6 +1,7 @@ use crate::{DType, Device, Error, Result, Tensor, WithDType}; +use safetensors::slice::SliceIterator; use safetensors::tensor as st; -pub use safetensors::tensor::SafeTensors; +use safetensors::tensor::{Dtype, SafeTensors}; use std::borrow::Cow; impl From for st::Dtype { @@ -63,15 +64,15 @@ impl Tensor { } } -fn convert_(view: &st::TensorView<'_>, device: &Device) -> Result { - let v = view.data(); +fn convert_slice(data: &[u8], shape: &[usize], device: &Device) -> Result { let size_in_bytes = T::DTYPE.size_in_bytes(); - let elem_count = v.len() / size_in_bytes; - if (v.as_ptr() as usize) % size_in_bytes == 0 { + let elem_count = data.len() / size_in_bytes; + if (data.as_ptr() as usize) % size_in_bytes == 0 { // SAFETY This is safe because we just checked that this // was correctly aligned. - let data: &[T] = unsafe { std::slice::from_raw_parts(v.as_ptr() as *const T, elem_count) }; - Tensor::from_slice(data, view.shape(), device) + let data: &[T] = + unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) }; + Tensor::from_slice(data, shape, device) } else { // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access @@ -81,13 +82,17 @@ fn convert_(view: &st::TensorView<'_>, device: &Device) -> Result< // We're downgrading the `c` pointer from T to u8, which removes alignment // constraints. 
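        // A plain byte copy is the fix: the `Vec<T>` allocation `c` is
        // guaranteed to be aligned for `T`, so filling it through a `*mut u8`
        // view and reading it back as `&[T]` is sound.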
unsafe { - std::ptr::copy_nonoverlapping(v.as_ptr(), c.as_mut_ptr() as *mut u8, v.len()); + std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len()); c.set_len(elem_count) } - Tensor::from_slice(&c, view.shape(), device) + Tensor::from_slice(&c, shape, device) } } +fn convert_(view: &st::TensorView<'_>, device: &Device) -> Result { + convert_slice::(view.data(), view.shape(), device) +} + fn convert_back_(mut vs: Vec) -> Vec { let size_in_bytes = T::DTYPE.size_in_bytes(); let length = vs.len() * size_in_bytes; @@ -112,6 +117,26 @@ impl<'a> Load for st::TensorView<'a> { } } +impl Tensor { + pub fn from_safetensors_slice( + iterator: SliceIterator, + dtype: Dtype, + shape: &[usize], + device: &Device, + ) -> Result { + let data: Vec = iterator.into_iter().flatten().cloned().collect(); + match dtype { + st::Dtype::U8 => convert_slice::(&data, shape, device), + st::Dtype::U32 => convert_slice::(&data, shape, device), + st::Dtype::BF16 => convert_slice::(&data, shape, device), + st::Dtype::F16 => convert_slice::(&data, shape, device), + st::Dtype::F32 => convert_slice::(&data, shape, device), + st::Dtype::F64 => convert_slice::(&data, shape, device), + dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)), + } + } +} + pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result { match view.dtype() { st::Dtype::U8 => convert_::(view, device), diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 4a989760f..2ecc85003 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -19,6 +19,8 @@ serde = { workspace = true } serde_json = { workspace = true } num-traits = { workspace = true } intel-mkl-src = { workspace = true, optional = true } +cudarc = { workspace = true, optional = true } +half = { workspace = true, optional = true } [dev-dependencies] anyhow = { workspace = true } @@ -40,3 +42,8 @@ default = [] cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] flash-attn = ["cuda", "dep:candle-flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] +nccl = ["dep:cudarc", "dep:half"] + +[[example]] +name = "llama_multiprocess" +required-features = ["cuda", "nccl"] diff --git a/candle-examples/examples/llama_multiprocess/main.rs b/candle-examples/examples/llama_multiprocess/main.rs new file mode 100644 index 000000000..22c121dd4 --- /dev/null +++ b/candle-examples/examples/llama_multiprocess/main.rs @@ -0,0 +1,266 @@ +// An implementation of LLaMA https://github.com/facebookresearch/llama +// +// This is based on nanoGPT in a similar way to: +// https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py +// +// The tokenizer config can be retrieved from: +// https://huggingface.co/hf-internal-testing/llama-tokenizer/raw/main/tokenizer.json +// +// In order to convert the llama weights to a .npz file, run: +// python examples/llama/convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth + +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +use anyhow::{Error as E, Result}; +use clap::Parser; + +use candle::{DType, Device, Tensor}; +use candle_nn::VarBuilder; +use candle_transformers::generation::LogitsProcessor; +use cudarc::driver::safe::CudaDevice; +use cudarc::nccl::safe::{Comm, Id}; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use std::io::Write; +use std::rc::Rc; + +mod model; +use model::{Config, Llama}; + +const MAX_SEQ_LEN: usize = 4096; +const DEFAULT_PROMPT: &str = r" +EDWARD: +I wonder how our princely father 'scaped, +Or whether he be 
'scaped away or no +From Clifford's and Northumberland's pursuit: +Had he been ta'en, we should have heard the news; +Had he been slain, we should have heard the news; +Or had he 'scaped, methinks we should have heard +The happy tidings of his good escape. +How fares my brother? why is he so sad? + +RICHARD: +I cannot joy, until I be resolved +Where our right valiant father is become. +I saw him in the battle range about; +And watch'd him how he singled Clifford forth. +Methought he bore him in the thickest troop +As doth a lion in a herd of neat; +Or as a bear, encompass'd round with dogs, +Who having pinch'd a few and made them cry, +The rest stand all aloof, and bark at him. +So fared our father with his enemies; +So fled his enemies my warlike father: +Methinks, 'tis prize enough to be his son. +See how the morning opes her golden gates, +And takes her farewell of the glorious sun! +How well resembles it the prime of youth, +Trimm'd like a younker prancing to his love! + +EDWARD: +Dazzle mine eyes, or do I see three suns? + +RICHARD: +Three glorious suns, each one a perfect sun; +Not separated with the racking clouds, +But sever'd in a pale clear-shining sky. +See, see! they join, embrace, and seem to kiss, +As if they vow'd some league inviolable: +Now are they but one lamp, one light, one sun. +In this the heaven figures some event. + +EDWARD: +'Tis wondrous strange, the like yet never heard of. +I think it cites us, brother, to the field, +That we, the sons of brave Plantagenet, +Each one already blazing by our meeds, +Should notwithstanding join our lights together +And over-shine the earth as this the world. +Whate'er it bodes, henceforward will I bear +Upon my target three fair-shining suns. +"; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + #[arg(long)] + num_shards: usize, + + #[arg(long)] + rank: Option, + + /// The temperature used to generate samples. + #[arg(long)] + temperature: Option, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// The length of the sample to generate (in tokens). + #[arg(long, default_value_t = 100)] + sample_len: usize, + + /// Disable the key-value cache. + #[arg(long)] + no_kv_cache: bool, + + /// The initial prompt. + #[arg(long)] + prompt: Option, + + /// Use f32 computations rather than f16. 
+ #[arg(long)] + use_f32: bool, + + #[arg(long)] + model_id: Option, + + #[arg(long)] + v2: bool, +} + +fn main() -> Result<()> { + use tokenizers::Tokenizer; + + let args = Args::parse(); + + let config = Config::config_7b(); + let dtype = if args.use_f32 { DType::F32 } else { DType::F16 }; + + let api = Api::new()?; + + let model_id = args.model_id.unwrap_or_else(|| { + if args.v2 { + "meta-llama/Llama-2-7b-hf".to_string() + } else { + "Narsil/amall-7b".to_string() + } + }); + println!("loading the model weights from {model_id}"); + let repo = Repo::new(model_id, RepoType::Model); + let tokenizer_filename = api.get(&repo, "tokenizer.json")?; + let mut filenames = vec![]; + for rfilename in [ + "model-00001-of-00002.safetensors", + "model-00002-of-00002.safetensors", + ] { + let filename = api.get(&repo, rfilename)?; + filenames.push(filename); + } + + if args.rank.is_none() { + let children: Vec<_> = (0..args.num_shards) + .map(|rank| { + let mut args: std::collections::VecDeque<_> = std::env::args().collect(); + args.push_back("--rank".to_string()); + args.push_back(format!("{rank}")); + let name = args.pop_front().unwrap(); + std::process::Command::new(name).args(args).spawn().unwrap() + }) + .collect(); + for mut child in children { + child.wait().unwrap(); + } + return Ok(()); + } + + let i = args.rank.unwrap(); + let num_shards = args.num_shards; + let rank = i; + // Primitive IPC + let id = if rank == 0 { + let id = Id::new().unwrap(); + std::fs::File::create("nccl_id.txt.tmp")? + .write_all(&id.internal().iter().map(|&i| i as u8).collect::>()) + .unwrap(); + std::fs::rename("nccl_id.txt.tmp", "nccl_id.txt")?; + id + } else { + let path = std::path::PathBuf::from("nccl_id.txt"); + while !path.exists() { + std::thread::sleep(std::time::Duration::from_secs(1)); + } + let data = std::fs::read("nccl_id.txt")?; + let internal: [i8; 128] = data + .into_iter() + .map(|i| i as i8) + .collect::>() + .try_into() + .unwrap(); + let id: Id = Id::uninit(internal); + id + }; + let device = CudaDevice::new(i)?; + let comm = Rc::new(Comm::from_rank(device, i, num_shards, id).unwrap()); + if rank == 0 { + std::fs::remove_file("nccl_id.txt")?; + } + println!("Rank {rank:?} spawned"); + + let device = Device::new_cuda(i)?; + let cache = model::Cache::new(!args.no_kv_cache, &config, &device)?; + + println!("building the model"); + let handles = filenames + .iter() + .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? })) + .collect::>>()?; + let tensors: Vec<_> = handles + .iter() + .map(|h| Ok(h.deserialize()?)) + .collect::>>()?; + + let vb = VarBuilder::from_safetensors(tensors, dtype, &device); + let llama = Llama::load(vb, &cache, &config, comm)?; + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + + let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str()); + let mut tokens = tokenizer + .encode(prompt, true) + .map_err(E::msg)? 
+ .get_ids() + .to_vec(); + + println!("starting the inference loop"); + let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature); + let mut new_tokens = vec![]; + let start_gen = std::time::Instant::now(); + let mut index_pos = 0; + for index in 0..args.sample_len { + let start_gen = std::time::Instant::now(); + let context_size = if cache.use_kv_cache && index > 0 { + 1 + } else { + tokens.len() + }; + let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; + let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; + let logits = llama.forward(&input, index_pos)?; + let logits = logits.squeeze(0)?; + index_pos += ctxt.len(); + + let next_token = logits_processor.sample(&logits)?; + tokens.push(next_token); + new_tokens.push(next_token); + println!("> {:?}", start_gen.elapsed()); + println!( + "{} token: {} '{}'", + index + 1, + next_token, + tokenizer.decode(vec![next_token], true).map_err(E::msg)? + ); + } + let dt = start_gen.elapsed(); + println!( + "{} tokens generated ({} token/s)\n----\n{}\n----", + args.sample_len, + args.sample_len as f64 / dt.as_secs_f64(), + tokenizer.decode(new_tokens, true).map_err(E::msg)? + ); + Ok(()) +} diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs new file mode 100644 index 000000000..4e46b5267 --- /dev/null +++ b/candle-examples/examples/llama_multiprocess/model.rs @@ -0,0 +1,465 @@ +use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle_nn::{Embedding, Linear, VarBuilder}; +use cudarc::driver::safe::CudaSlice; +use cudarc::nccl::safe::{Comm, ReduceOp}; +use half::f16; +use std::collections::HashMap; +use std::rc::Rc; +use std::sync::{Arc, Mutex}; + +use super::MAX_SEQ_LEN; + +struct TensorParallelColumnLinear { + linear: Linear, +} + +impl TensorParallelColumnLinear { + fn new(linear: Linear) -> Self { + Self { linear } + } + fn forward(&self, x: &Tensor) -> Result { + self.linear.forward(x) + } +} + +struct TensorParallelRowLinear { + linear: Linear, + comm: Rc, +} + +fn all_reduce_sum(x: &Tensor, comm: &Rc) -> Result { + Ok(x.clone()) + // let n = x.shape().elem_count(); + // let cuda_slice: CudaSlice = x.try_into()?; + // let dev = cuda_slice.device(); + // let mut slice_receive = dev.alloc_zeros(n).unwrap(); + // comm.all_reduce(cuda_slice, &mut slice_receive, &ReduceOp::Sum).unwrap(); + // Tensor::from_raw_storage(slice_receive, x.shape()) +} + +impl TensorParallelRowLinear { + fn new(linear: Linear, comm: Rc) -> Self { + Self { linear, comm } + } + fn forward(&self, x: &Tensor) -> Result { + let x = self.linear.forward(x)?; + all_reduce_sum(&x, &self.comm) + } +} + +impl TensorParallelColumnLinear { + fn load(vb: VarBuilder, comm: Rc) -> Result { + let rank = comm.rank(); + let size = comm.world_size(); + let weight = vb.get_sharded("weight", 0, rank, size)?; + Ok(Self::new(Linear::new(weight, None))) + } + + fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Rc) -> Result { + let rank = comm.rank(); + let size = comm.world_size(); + let weights: Vec<_> = prefixes + .iter() + .map(|p| vb.pp(p).get_sharded("weight", 0, rank, size).unwrap()) + .collect(); + let weight = Tensor::cat(&weights, 0)?; + Ok(Self::new(Linear::new(weight, None))) + } +} + +impl TensorParallelRowLinear { + fn load(vb: VarBuilder, comm: Rc) -> Result { + let rank = comm.rank(); + let size = comm.world_size(); + let weight = vb.get_sharded("weight", 1, rank, size)?; + Ok(Self::new(Linear::new(weight, None), comm.clone())) + } +} + +pub struct Config { + pub 
hidden_size: usize, + pub intermediate_size: usize, + pub vocab_size: usize, + pub n_layer: usize, + pub n_head: usize, + pub n_embd: usize, + pub n_key_value_head: usize, +} + +impl Config { + pub fn config_7b() -> Self { + Self { + hidden_size: 4096, + intermediate_size: 11008, + vocab_size: 32000, + n_layer: 32, + n_head: 32, + n_embd: 4096, + n_key_value_head: 32, + } + } +} + +#[derive(Clone)] +pub struct Cache { + masks: Arc>>, + pub use_kv_cache: bool, + #[allow(clippy::type_complexity)] + kvs: Arc>>>, + cos: Tensor, + sin: Tensor, + device: Device, +} + +impl Cache { + pub fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Result { + // precompute freqs_cis + let n_elem = config.n_embd / config.n_head; + let theta: Vec<_> = (0..n_elem) + .step_by(2) + .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32)) + .collect(); + let theta = Tensor::new(theta.as_slice(), device)?; + let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)? + .to_dtype(DType::F32)? + .reshape((MAX_SEQ_LEN, 1))? + .matmul(&theta.reshape((1, theta.elem_count()))?)?; + // This is different from the paper, see: + // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112 + let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?; + let cos = idx_theta.cos()?; + let sin = idx_theta.sin()?; + Ok(Self { + masks: Arc::new(Mutex::new(HashMap::new())), + use_kv_cache, + kvs: Arc::new(Mutex::new(vec![None; config.n_layer])), + device: device.clone(), + cos, + sin, + }) + } + + fn mask(&self, t: usize) -> Result { + let mut masks = self.masks.lock().unwrap(); + if let Some(mask) = masks.get(&t) { + Ok(mask.clone()) + } else { + // TODO: If we support bool or u8 tensors, this would be better. + let mask: Vec<_> = (0..t) + .flat_map(|i| (0..t).map(move |j| u32::from(j > i))) + .collect(); + let mask = Tensor::from_slice(&mask, (t, t), &self.device)?; + masks.insert(t, mask.clone()); + Ok(mask) + } + } +} + +fn silu(xs: &Tensor) -> Result { + xs / (xs.neg()?.exp()? + 1.0)? +} + +fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result { + let weight = vb.get((size2, size1), "weight")?; + Ok(Linear::new(weight, None)) +} + +fn embedding(cfg: &Config, vb: VarBuilder) -> Result { + let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?; + Ok(Embedding::new(embeddings, cfg.hidden_size)) +} + +struct RmsNorm { + scale: Tensor, +} + +impl RmsNorm { + fn load(size: usize, vb: VarBuilder) -> Result { + let scale = vb.get(size, "weight")?; + Ok(Self::new(scale)) + } + + fn new(scale: Tensor) -> Self { + Self { scale } + } + + fn forward(&self, x: &Tensor) -> Result { + let in_dtype = x.dtype(); + // This is a no-op if x's dtype is already f32. + let x = x.to_dtype(DType::F32)?; + let (b_sz, seq_len, hidden_size) = x.shape().r3()?; + let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?; + let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?; + let x_normed = (x / (norm_x + 1e-6)?.sqrt()?)?; + let size = self.scale.shape().r1()?; + let scale = self + .scale + .to_dtype(DType::F32)? 
+ .broadcast_as((b_sz, seq_len, size))?; + let x = (scale * x_normed)?; + let x = x.to_dtype(in_dtype)?; + Ok(x) + } +} + +struct CausalSelfAttention { + qkv_proj: TensorParallelColumnLinear, + o_proj: TensorParallelRowLinear, + n_head: usize, + n_key_value_head: usize, + head_dim: usize, + cache: Cache, +} + +impl CausalSelfAttention { + fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result { + let (b_sz, _, seq_len, n_embd) = x.shape().r4()?; + let cos = self.cache.cos.narrow(0, index_pos, seq_len)?; + let sin = self.cache.sin.narrow(0, index_pos, seq_len)?; + let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?; + let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd))?; + let x1 = x.narrow(D::Minus1, 0, n_embd / 2)?; + let x2 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?; + let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?; + let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?; + Ok(rope) + } + + fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result { + let x_dtype = x.dtype(); + let (b_sz, seq_len, _) = x.shape().r3()?; + + let qkv = self.qkv_proj.forward(x)?; + let n_embd = self.n_head * self.head_dim; + + let q = qkv.i((.., .., ..self.n_head * self.head_dim))?; + let k = qkv.i(( + .., + .., + self.n_head * self.head_dim + ..self.n_head * self.head_dim + self.n_key_value_head * self.head_dim, + ))?; + let v = qkv.i(( + .., + .., + self.n_head * self.head_dim + self.n_key_value_head * self.head_dim.., + ))?; + // todo!("Q {:?} K {:?} V {:?} - x {:?}", q.shape(), k.shape(), v.shape(), x.shape()); + + let q = q + .reshape((b_sz, seq_len, self.n_head, self.head_dim))? + .transpose(1, 2)? + .to_dtype(DType::F32)?; + let k = k + .reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))? + .transpose(1, 2)? + .to_dtype(DType::F32)?; + let mut v = v + .reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))? + .transpose(1, 2)? + .to_dtype(DType::F32)?; + + let q = self.apply_rotary_emb(&q, index_pos)?; + let mut k = self.apply_rotary_emb(&k, index_pos)?; + + if self.cache.use_kv_cache { + let mut cache = self.cache.kvs.lock().unwrap(); + if let Some((cache_k, cache_v)) = &cache[block_idx] { + k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?; + v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?; + let k_seq_len = k.dims()[1]; + if k_seq_len > MAX_SEQ_LEN { + k = k + .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)? + .contiguous()? + } + let v_seq_len = v.dims()[1]; + if v_seq_len > 2 * MAX_SEQ_LEN { + v = v + .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)? + .contiguous()? + } + } + cache[block_idx] = Some((k.clone(), v.clone())) + } + + let k = self.repeat_kv(k)?; + let v = self.repeat_kv(v)?; + let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; + let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?; + let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; + let att = att.softmax(D::Minus1)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + let y = att.matmul(&v.contiguous()?)?; + let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; + let y = y.to_dtype(x_dtype)?; + let y = self.o_proj.forward(&y)?; + Ok(y) + } + + fn repeat_kv(&self, x: Tensor) -> Result { + let n_rep = self.n_head / self.n_key_value_head; + if n_rep == 1 { + Ok(x) + } else { + let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().r4()?; + let x = x + .unsqueeze(2)? + .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))? 
+ .reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?; + Ok(x) + } + } + + fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc) -> Result { + let size_in = cfg.hidden_size; + let size_q = (cfg.hidden_size / cfg.n_head) * cfg.n_head; + let size_kv = (cfg.hidden_size / cfg.n_head) * cfg.n_key_value_head; + + let qkv_proj = TensorParallelColumnLinear::load_multi( + vb.clone(), + &["q_proj", "k_proj", "v_proj"], + comm.clone(), + )?; + let o_proj = TensorParallelRowLinear::load(vb.pp("o_proj"), comm.clone())?; + Ok(Self { + qkv_proj, + o_proj, + n_head: cfg.n_head / comm.world_size(), + n_key_value_head: cfg.n_key_value_head / comm.world_size(), + head_dim: cfg.hidden_size / cfg.n_head, + cache: cache.clone(), + }) + } +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; + Ok(m) +} + +struct Mlp { + c_fc1: TensorParallelColumnLinear, + c_fc2: TensorParallelColumnLinear, + c_proj: TensorParallelRowLinear, +} + +impl Mlp { + fn new( + c_fc1: TensorParallelColumnLinear, + c_fc2: TensorParallelColumnLinear, + c_proj: TensorParallelRowLinear, + ) -> Self { + Self { + c_fc1, + c_fc2, + c_proj, + } + } + + fn forward(&self, x: &Tensor) -> Result { + let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?; + self.c_proj.forward(&x) + } + + fn load(vb: VarBuilder, cfg: &Config, comm: Rc) -> Result { + let h_size = cfg.hidden_size; + let i_size = cfg.intermediate_size; + let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?; + let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?; + let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm.clone())?; + Ok(Self::new(c_fc1, c_fc2, c_proj)) + } +} + +struct Block { + rms_1: RmsNorm, + attn: CausalSelfAttention, + rms_2: RmsNorm, + mlp: Mlp, +} + +impl Block { + fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self { + Self { + rms_1, + attn, + rms_2, + mlp, + } + } + + fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result { + let residual = x; + let x = self.rms_1.forward(x)?; + let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?; + let residual = &x; + let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? 
+ residual)?; + Ok(x) + } + + fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc) -> Result { + let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?; + let mlp = Mlp::load(vb.pp("mlp"), cfg, comm.clone())?; + let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?; + let post_attention_layernorm = + RmsNorm::load(cfg.hidden_size, vb.pp("post_attention_layernorm"))?; + Ok(Self::new( + input_layernorm, + attn, + post_attention_layernorm, + mlp, + )) + } +} + +pub struct Llama { + wte: Embedding, + blocks: Vec, + ln_f: RmsNorm, + lm_head: Linear, +} + +impl Llama { + fn new(wte: Embedding, blocks: Vec, ln_f: RmsNorm, lm_head: Linear) -> Self { + Self { + wte, + blocks, + ln_f, + lm_head, + } + } + + pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result { + let (_b_sz, seq_len) = x.shape().r2()?; + let mut x = self.wte.forward(x)?; + for (block_idx, block) in self.blocks.iter().enumerate() { + x = block.forward(&x, index_pos, block_idx)?; + } + let x = self.ln_f.forward(&x)?; + let x = x.i((.., seq_len - 1, ..))?; + let logits = self.lm_head.forward(&x)?; + logits.to_dtype(DType::F32) + } + + pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc) -> Result { + let wte = embedding(cfg, vb.pp("model.embed_tokens"))?; + let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?; + let blocks: Vec<_> = (0..cfg.n_layer) + .map(|i| { + Block::load( + vb.pp(&format!("model.layers.{i}")), + cache, + cfg, + comm.clone(), + ) + .unwrap() + }) + .collect(); + + Ok(Self::new(wte, blocks, norm, lm_head)) + } +} diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index 18d9e0e65..953353237 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -14,6 +14,7 @@ readme = "README.md" candle = { path = "../candle-core" } thiserror = { workspace = true } intel-mkl-src = { workspace = true, optional = true } +safetensors = { workspace = true } [dev-dependencies] anyhow = { workspace = true } diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 87dd2a7f2..b02d216b1 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -1,7 +1,6 @@ -use candle::{ - safetensors::{Load, SafeTensors}, - DType, Device, Error, Result, Shape, Tensor, -}; +use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor}; +use safetensors::slice::IndexOp; +use safetensors::tensor::SafeTensors; use std::collections::HashMap; use std::sync::Arc; @@ -71,7 +70,7 @@ impl<'a> TensorData<'a> { #[derive(Clone)] pub struct VarBuilder<'a> { data: Arc>, - path: Vec, + pub path: Vec, } impl<'a> VarBuilder<'a> { @@ -137,6 +136,55 @@ impl<'a> VarBuilder<'a> { } impl<'a> VarBuilder<'a> { + pub fn get_sharded( + &self, + tensor_name: &str, + dim: usize, + rank: usize, + world_size: usize, + ) -> Result { + let data = self.data.as_ref(); + let path = if self.path.is_empty() { + tensor_name.to_string() + } else { + [&self.path.join("."), tensor_name].join(".") + }; + let tensor = match &self.data.tensors { + Tensors::SafeTensorWithRouting { + routing, + safetensors, + } => { + let index = routing.get(&path).ok_or_else(|| { + Error::CannotFindTensor { + path: path.to_string(), + } + .bt() + })?; + + let view = safetensors[*index].tensor(&path)?; + let dtype = view.dtype(); + let mut shape = view.shape().to_vec(); + let size = shape[dim]; + let block_size = size / world_size; + let start = rank * block_size; + let stop = 
(rank + 1) * block_size; + + let iterator = if dim == 0 { + view.slice(start..stop).unwrap() + } else if dim == 1 { + view.slice((.., start..stop)).unwrap() + } else { + unimplemented!("Get sharded on dimensions != 0 or 1"); + }; + + shape[dim] = block_size; + + Tensor::from_safetensors_slice(iterator, dtype, &shape, &data.device)? + } + _ => unimplemented!(), + }; + Ok(tensor) + } pub fn get>(&self, s: S, tensor_name: &str) -> Result { let data = self.data.as_ref(); let s: Shape = s.into(); diff --git a/candle-wasm-examples/whisper/Cargo.toml b/candle-wasm-examples/whisper/Cargo.toml index 55b507db5..4ebb27884 100644 --- a/candle-wasm-examples/whisper/Cargo.toml +++ b/candle-wasm-examples/whisper/Cargo.toml @@ -15,6 +15,7 @@ candle = { path = "../../candle-core" } candle-nn = { path = "../../candle-nn" } num-traits = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } +safetensors = { workspace = true } # App crates. anyhow = { workspace = true } diff --git a/candle-wasm-examples/whisper/src/worker.rs b/candle-wasm-examples/whisper/src/worker.rs index ea64bf029..62eaa16ff 100644 --- a/candle-wasm-examples/whisper/src/worker.rs +++ b/candle-wasm-examples/whisper/src/worker.rs @@ -236,11 +236,11 @@ impl Decoder { let device = Device::Cpu; let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(anyhow::Error::msg)?; - let mel_filters = candle::safetensors::SafeTensors::deserialize(&md.mel_filters)?; + let mel_filters = safetensors::tensor::SafeTensors::deserialize(&md.mel_filters)?; let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?; console_log!("loaded mel filters {:?}", mel_filters.shape()); let mel_filters = mel_filters.flatten_all()?.to_vec1::()?; - let weights = candle::safetensors::SafeTensors::deserialize(&md.weights)?; + let weights = safetensors::tensor::SafeTensors::deserialize(&md.weights)?; let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device); let config = Config::tiny_en(); let whisper = Whisper::load(&vb, config)?; From ed58de7551fd3dabffa3e8f817ffa4211e9fe4a5 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 25 Jul 2023 19:29:03 +0000 Subject: [PATCH 2/8] Fixed TP sharded version. --- Cargo.toml | 2 +- candle-core/src/op.rs | 2 +- .../examples/llama_multiprocess/main.rs | 28 ++++---- .../examples/llama_multiprocess/model.rs | 66 ++++++++++++------- 4 files changed, 62 insertions(+), 36 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d9613cdce..b02059ae1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ members = [ "candle-core", "candle-examples", "candle-nn", - "candle-pyo3", + # "candle-pyo3", "candle-transformers", "candle-wasm-examples/llama2-c", "candle-wasm-examples/whisper", diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 525383b22..83b382cd6 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -103,7 +103,7 @@ pub enum Op { } /// Unary ops that can be defined in user-land. -pub trait CustomOp1: Send + Sync { +pub trait CustomOp1 { // Box does not support const yet, so use a function to get the name. 
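    // NOTE: the `Send + Sync` supertraits are dropped here so that an op can
    // capture non-thread-safe state, e.g. the `Rc<Comm>` NCCL handle held by
    // the AllReduce op in the llama_multiprocess example below.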
fn name(&self) -> &'static str; diff --git a/candle-examples/examples/llama_multiprocess/main.rs b/candle-examples/examples/llama_multiprocess/main.rs index 22c121dd4..f9e874323 100644 --- a/candle-examples/examples/llama_multiprocess/main.rs +++ b/candle-examples/examples/llama_multiprocess/main.rs @@ -247,20 +247,24 @@ fn main() -> Result<()> { let next_token = logits_processor.sample(&logits)?; tokens.push(next_token); new_tokens.push(next_token); - println!("> {:?}", start_gen.elapsed()); + if rank == 0 { + println!("> {:?}", start_gen.elapsed()); + println!( + "{} token: {} '{}'", + index + 1, + next_token, + tokenizer.decode(vec![next_token], true).map_err(E::msg)? + ); + } + } + let dt = start_gen.elapsed(); + if rank == 0 { println!( - "{} token: {} '{}'", - index + 1, - next_token, - tokenizer.decode(vec![next_token], true).map_err(E::msg)? + "{} tokens generated ({} token/s)\n----\n{}\n----", + args.sample_len, + args.sample_len as f64 / dt.as_secs_f64(), + tokenizer.decode(new_tokens, true).map_err(E::msg)? ); } - let dt = start_gen.elapsed(); - println!( - "{} tokens generated ({} token/s)\n----\n{}\n----", - args.sample_len, - args.sample_len as f64 / dt.as_secs_f64(), - tokenizer.decode(new_tokens, true).map_err(E::msg)? - ); Ok(()) } diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs index 4e46b5267..e902734f2 100644 --- a/candle-examples/examples/llama_multiprocess/model.rs +++ b/candle-examples/examples/llama_multiprocess/model.rs @@ -1,6 +1,6 @@ -use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle::backend::BackendStorage; +use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D}; use candle_nn::{Embedding, Linear, VarBuilder}; -use cudarc::driver::safe::CudaSlice; use cudarc::nccl::safe::{Comm, ReduceOp}; use half::f16; use std::collections::HashMap; @@ -27,14 +27,42 @@ struct TensorParallelRowLinear { comm: Rc, } +struct AllReduce { + comm: Rc, +} + +impl CustomOp1 for AllReduce { + fn name(&self) -> &'static str { + "allreduce" + } + + fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> { + todo!("implement allreduce for cpu is not necessary for single node"); + } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + s: &candle::CudaStorage, + l: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + use candle::cuda_backend::WrapErr; + let elem_count = l.shape().elem_count(); + let dev = s.device().clone(); + let s = s.as_cuda_slice::()?; + // let s = match l.contiguous_offsets() { + // None => Err(Error::Wrapped("input has to be contiguous".into()))?, + // Some((o1, o2)) => s.slice(o1..o2), + // }; + let mut dst = unsafe { dev.alloc::(elem_count) }.w()?; + self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap(); + let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev); + Ok((dst, l.shape().clone())) + } +} + fn all_reduce_sum(x: &Tensor, comm: &Rc) -> Result { - Ok(x.clone()) - // let n = x.shape().elem_count(); - // let cuda_slice: CudaSlice = x.try_into()?; - // let dev = cuda_slice.device(); - // let mut slice_receive = dev.alloc_zeros(n).unwrap(); - // comm.all_reduce(cuda_slice, &mut slice_receive, &ReduceOp::Sum).unwrap(); - // Tensor::from_raw_storage(slice_receive, x.shape()) + x.custom_op1(AllReduce { comm: comm.clone() }) } impl TensorParallelRowLinear { @@ -187,11 +215,11 @@ impl RmsNorm { let in_dtype = x.dtype(); // This is a no-op if x's dtype is already f32. 
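        // RMSNorm: y = scale * x / sqrt(mean(x^2) + 1e-6); the reduction below
        // runs in f32 and the result is cast back to the input dtype at the end.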
let x = x.to_dtype(DType::F32)?; - let (b_sz, seq_len, hidden_size) = x.shape().r3()?; + let (b_sz, seq_len, hidden_size) = x.shape().dims3()?; let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?; let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?; let x_normed = (x / (norm_x + 1e-6)?.sqrt()?)?; - let size = self.scale.shape().r1()?; + let size = self.scale.shape().dims1()?; let scale = self .scale .to_dtype(DType::F32)? @@ -213,7 +241,7 @@ struct CausalSelfAttention { impl CausalSelfAttention { fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result { - let (b_sz, _, seq_len, n_embd) = x.shape().r4()?; + let (b_sz, _, seq_len, n_embd) = x.shape().dims4()?; let cos = self.cache.cos.narrow(0, index_pos, seq_len)?; let sin = self.cache.sin.narrow(0, index_pos, seq_len)?; let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?; @@ -227,7 +255,7 @@ impl CausalSelfAttention { fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result { let x_dtype = x.dtype(); - let (b_sz, seq_len, _) = x.shape().r3()?; + let (b_sz, seq_len, _) = x.shape().dims3()?; let qkv = self.qkv_proj.forward(x)?; let n_embd = self.n_head * self.head_dim; @@ -302,7 +330,7 @@ impl CausalSelfAttention { if n_rep == 1 { Ok(x) } else { - let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().r4()?; + let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().dims4()?; let x = x .unsqueeze(2)? .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))? @@ -312,10 +340,6 @@ impl CausalSelfAttention { } fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc) -> Result { - let size_in = cfg.hidden_size; - let size_q = (cfg.hidden_size / cfg.n_head) * cfg.n_head; - let size_kv = (cfg.hidden_size / cfg.n_head) * cfg.n_key_value_head; - let qkv_proj = TensorParallelColumnLinear::load_multi( vb.clone(), &["q_proj", "k_proj", "v_proj"], @@ -364,9 +388,7 @@ impl Mlp { self.c_proj.forward(&x) } - fn load(vb: VarBuilder, cfg: &Config, comm: Rc) -> Result { - let h_size = cfg.hidden_size; - let i_size = cfg.intermediate_size; + fn load(vb: VarBuilder, _cfg: &Config, comm: Rc) -> Result { let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?; let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?; let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm.clone())?; @@ -433,7 +455,7 @@ impl Llama { } pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result { - let (_b_sz, seq_len) = x.shape().r2()?; + let (_b_sz, seq_len) = x.shape().dims2()?; let mut x = self.wte.forward(x)?; for (block_idx, block) in self.blocks.iter().enumerate() { x = block.forward(&x, index_pos, block_idx)?; From b7814f66b4844f2f0c5d35aed465f4f40a5f6845 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 25 Jul 2023 22:44:30 +0200 Subject: [PATCH 3/8] PyO3 is back. 
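
Re-enable candle-pyo3 in the workspace. The previous commit dropped the
`Send + Sync` supertraits on `CustomOp1` (the NCCL `AllReduce` op captures an
`Rc<Comm>`, which is neither `Send` nor `Sync`), so `Tensor` no longer gets
those auto traits; the Python wrapper asserts them manually for now.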
--- Cargo.toml | 2 +- candle-pyo3/src/lib.rs | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index b02059ae1..d9613cdce 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ members = [ "candle-core", "candle-examples", "candle-nn", - # "candle-pyo3", + "candle-pyo3", "candle-transformers", "candle-wasm-examples/llama2-c", "candle-wasm-examples/whisper", diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 136f8a4f7..c81cc7136 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -14,6 +14,9 @@ pub fn wrap_err(err: ::candle::Error) -> PyErr { #[pyclass(name = "Tensor")] struct PyTensor(Tensor); +unsafe impl Send for PyTensor {} +unsafe impl Sync for PyTensor {} + impl std::ops::Deref for PyTensor { type Target = Tensor; From 1553b58fe59a29fe808b9b4d43a6502046ce26dd Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 25 Jul 2023 22:53:09 +0200 Subject: [PATCH 4/8] Tensor are not necessarily sendable (CustomOp1). --- candle-pyo3/src/lib.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index c81cc7136..6e206688f 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -11,12 +11,9 @@ pub fn wrap_err(err: ::candle::Error) -> PyErr { } #[derive(Clone)] -#[pyclass(name = "Tensor")] +#[pyclass(name = "Tensor", unsendable)] struct PyTensor(Tensor); -unsafe impl Send for PyTensor {} -unsafe impl Sync for PyTensor {} - impl std::ops::Deref for PyTensor { type Target = Tensor; From 7c7e6ba201d0270f5ac689c20f16f59e00ed4d01 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 26 Jul 2023 11:16:04 +0200 Subject: [PATCH 5/8] Removing inner dependency on safetensors. --- candle-core/src/safetensors.rs | 27 +++++++++---------- .../examples/llama_multiprocess/model.rs | 23 ++++++++-------- candle-nn/src/var_builder.rs | 10 ++++--- candle-wasm-examples/whisper/Cargo.toml | 2 +- 4 files changed, 30 insertions(+), 32 deletions(-) diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index e81fe184e..dee57b37c 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -1,7 +1,6 @@ use crate::{DType, Device, Error, Result, Tensor, WithDType}; -use safetensors::slice::SliceIterator; use safetensors::tensor as st; -use safetensors::tensor::{Dtype, SafeTensors}; +use safetensors::tensor::SafeTensors; use std::borrow::Cow; impl From for st::Dtype { @@ -118,26 +117,24 @@ impl<'a> Load for st::TensorView<'a> { } impl Tensor { - pub fn from_safetensors_slice( - iterator: SliceIterator, - dtype: Dtype, + pub fn from_raw_buffer( + data: &[u8], + dtype: DType, shape: &[usize], device: &Device, ) -> Result { - let data: Vec = iterator.into_iter().flatten().cloned().collect(); match dtype { - st::Dtype::U8 => convert_slice::(&data, shape, device), - st::Dtype::U32 => convert_slice::(&data, shape, device), - st::Dtype::BF16 => convert_slice::(&data, shape, device), - st::Dtype::F16 => convert_slice::(&data, shape, device), - st::Dtype::F32 => convert_slice::(&data, shape, device), - st::Dtype::F64 => convert_slice::(&data, shape, device), - dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)), + DType::U8 => convert_slice::(data, shape, device), + DType::U32 => convert_slice::(data, shape, device), + DType::BF16 => convert_slice::(data, shape, device), + DType::F16 => convert_slice::(data, shape, device), + DType::F32 => convert_slice::(data, shape, device), + DType::F64 => convert_slice::(data, shape, device), } 
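        // Illustrative use, assuming a little-endian host (the conversions
        // above reinterpret native-endian bytes):
        //     let bytes: Vec<u8> = [1f32, 2., 3., 4.].iter().flat_map(|f| f.to_le_bytes()).collect();
        //     let t = Tensor::from_raw_buffer(&bytes, DType::F32, &[2, 2], &Device::Cpu)?;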
} } -pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result { +fn convert(view: &st::TensorView<'_>, device: &Device) -> Result { match view.dtype() { st::Dtype::U8 => convert_::(view, device), st::Dtype::U32 => convert_::(view, device), @@ -149,7 +146,7 @@ pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result { } } -pub fn convert_back(tensor: &Tensor) -> Result> { +fn convert_back(tensor: &Tensor) -> Result> { // TODO: This makes an unnecessary copy when the tensor is on the cpu. let tensor = tensor.flatten_all()?; match tensor.dtype() { diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs index e902734f2..becaa8798 100644 --- a/candle-examples/examples/llama_multiprocess/model.rs +++ b/candle-examples/examples/llama_multiprocess/model.rs @@ -4,7 +4,6 @@ use candle_nn::{Embedding, Linear, VarBuilder}; use cudarc::nccl::safe::{Comm, ReduceOp}; use half::f16; use std::collections::HashMap; -use std::rc::Rc; use std::sync::{Arc, Mutex}; use super::MAX_SEQ_LEN; @@ -24,11 +23,11 @@ impl TensorParallelColumnLinear { struct TensorParallelRowLinear { linear: Linear, - comm: Rc, + comm: Arc, } struct AllReduce { - comm: Rc, + comm: Arc, } impl CustomOp1 for AllReduce { @@ -61,12 +60,12 @@ impl CustomOp1 for AllReduce { } } -fn all_reduce_sum(x: &Tensor, comm: &Rc) -> Result { +fn all_reduce_sum(x: &Tensor, comm: &Arc) -> Result { x.custom_op1(AllReduce { comm: comm.clone() }) } impl TensorParallelRowLinear { - fn new(linear: Linear, comm: Rc) -> Self { + fn new(linear: Linear, comm: Arc) -> Self { Self { linear, comm } } fn forward(&self, x: &Tensor) -> Result { @@ -76,14 +75,14 @@ impl TensorParallelRowLinear { } impl TensorParallelColumnLinear { - fn load(vb: VarBuilder, comm: Rc) -> Result { + fn load(vb: VarBuilder, comm: Arc) -> Result { let rank = comm.rank(); let size = comm.world_size(); let weight = vb.get_sharded("weight", 0, rank, size)?; Ok(Self::new(Linear::new(weight, None))) } - fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Rc) -> Result { + fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Arc) -> Result { let rank = comm.rank(); let size = comm.world_size(); let weights: Vec<_> = prefixes @@ -96,7 +95,7 @@ impl TensorParallelColumnLinear { } impl TensorParallelRowLinear { - fn load(vb: VarBuilder, comm: Rc) -> Result { + fn load(vb: VarBuilder, comm: Arc) -> Result { let rank = comm.rank(); let size = comm.world_size(); let weight = vb.get_sharded("weight", 1, rank, size)?; @@ -339,7 +338,7 @@ impl CausalSelfAttention { } } - fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc) -> Result { + fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc) -> Result { let qkv_proj = TensorParallelColumnLinear::load_multi( vb.clone(), &["q_proj", "k_proj", "v_proj"], @@ -388,7 +387,7 @@ impl Mlp { self.c_proj.forward(&x) } - fn load(vb: VarBuilder, _cfg: &Config, comm: Rc) -> Result { + fn load(vb: VarBuilder, _cfg: &Config, comm: Arc) -> Result { let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?; let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?; let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm.clone())?; @@ -422,7 +421,7 @@ impl Block { Ok(x) } - fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc) -> Result { + fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc) -> Result { let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?; let 
mlp = Mlp::load(vb.pp("mlp"), cfg, comm.clone())?; let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?; @@ -466,7 +465,7 @@ impl Llama { logits.to_dtype(DType::F32) } - pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc) -> Result { + pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc) -> Result { let wte = embedding(cfg, vb.pp("model.embed_tokens"))?; let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?; diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index b02d216b1..1466f6d01 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -1,6 +1,5 @@ use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor}; -use safetensors::slice::IndexOp; -use safetensors::tensor::SafeTensors; +use safetensors::{slice::IndexOp, tensor::SafeTensors}; use std::collections::HashMap; use std::sync::Arc; @@ -70,7 +69,7 @@ impl<'a> TensorData<'a> { #[derive(Clone)] pub struct VarBuilder<'a> { data: Arc>, - pub path: Vec, + path: Vec, } impl<'a> VarBuilder<'a> { @@ -179,7 +178,10 @@ impl<'a> VarBuilder<'a> { shape[dim] = block_size; - Tensor::from_safetensors_slice(iterator, dtype, &shape, &data.device)? + let dtype: DType = dtype.try_into()?; + + let raw: Vec = iterator.into_iter().flatten().cloned().collect(); + Tensor::from_raw_buffer(&raw, dtype, &shape, &data.device)? } _ => unimplemented!(), }; diff --git a/candle-wasm-examples/whisper/Cargo.toml b/candle-wasm-examples/whisper/Cargo.toml index 4ebb27884..b51d40527 100644 --- a/candle-wasm-examples/whisper/Cargo.toml +++ b/candle-wasm-examples/whisper/Cargo.toml @@ -15,7 +15,6 @@ candle = { path = "../../candle-core" } candle-nn = { path = "../../candle-nn" } num-traits = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } -safetensors = { workspace = true } # App crates. anyhow = { workspace = true } @@ -24,6 +23,7 @@ rand = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } wav = { workspace = true } +safetensors = { workspace = true } # Wasm specific crates. getrandom = { version = "0.2", features = ["js"] } From 25a2086e8f4cc23fada32a44607d3b8550916ebe Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 26 Jul 2023 10:22:40 +0000 Subject: [PATCH 6/8] Putting back Send + Sync --- candle-core/src/op.rs | 2 +- .../examples/llama_multiprocess/model.rs | 30 ++++++++++++------- candle-pyo3/src/lib.rs | 2 +- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 83b382cd6..525383b22 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -103,7 +103,7 @@ pub enum Op { } /// Unary ops that can be defined in user-land. -pub trait CustomOp1 { +pub trait CustomOp1: Send + Sync { // Box does not support const yet, so use a function to get the name. 
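    // NOTE: implementors must be thread-safe again; ops wrapping resources
    // that are not, like the NCCL communicator in the llama_multiprocess
    // AllReduce, opt in explicitly with `unsafe impl Send`/`Sync`.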
fn name(&self) -> &'static str; diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs index becaa8798..bcf6ed2bf 100644 --- a/candle-examples/examples/llama_multiprocess/model.rs +++ b/candle-examples/examples/llama_multiprocess/model.rs @@ -4,6 +4,7 @@ use candle_nn::{Embedding, Linear, VarBuilder}; use cudarc::nccl::safe::{Comm, ReduceOp}; use half::f16; use std::collections::HashMap; +use std::rc::Rc; use std::sync::{Arc, Mutex}; use super::MAX_SEQ_LEN; @@ -23,13 +24,20 @@ impl TensorParallelColumnLinear { struct TensorParallelRowLinear { linear: Linear, - comm: Arc, + comm: Rc, } struct AllReduce { - comm: Arc, + comm: Rc, } +/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html +/// But for this example purposes, this will work +unsafe impl Sync for AllReduce {} +/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html +/// But for this example purposes, this will work +unsafe impl Send for AllReduce {} + impl CustomOp1 for AllReduce { fn name(&self) -> &'static str { "allreduce" @@ -60,12 +68,12 @@ impl CustomOp1 for AllReduce { } } -fn all_reduce_sum(x: &Tensor, comm: &Arc) -> Result { +fn all_reduce_sum(x: &Tensor, comm: &Rc) -> Result { x.custom_op1(AllReduce { comm: comm.clone() }) } impl TensorParallelRowLinear { - fn new(linear: Linear, comm: Arc) -> Self { + fn new(linear: Linear, comm: Rc) -> Self { Self { linear, comm } } fn forward(&self, x: &Tensor) -> Result { @@ -75,14 +83,14 @@ impl TensorParallelRowLinear { } impl TensorParallelColumnLinear { - fn load(vb: VarBuilder, comm: Arc) -> Result { + fn load(vb: VarBuilder, comm: Rc) -> Result { let rank = comm.rank(); let size = comm.world_size(); let weight = vb.get_sharded("weight", 0, rank, size)?; Ok(Self::new(Linear::new(weight, None))) } - fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Arc) -> Result { + fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Rc) -> Result { let rank = comm.rank(); let size = comm.world_size(); let weights: Vec<_> = prefixes @@ -95,7 +103,7 @@ impl TensorParallelColumnLinear { } impl TensorParallelRowLinear { - fn load(vb: VarBuilder, comm: Arc) -> Result { + fn load(vb: VarBuilder, comm: Rc) -> Result { let rank = comm.rank(); let size = comm.world_size(); let weight = vb.get_sharded("weight", 1, rank, size)?; @@ -338,7 +346,7 @@ impl CausalSelfAttention { } } - fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc) -> Result { + fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc) -> Result { let qkv_proj = TensorParallelColumnLinear::load_multi( vb.clone(), &["q_proj", "k_proj", "v_proj"], @@ -387,7 +395,7 @@ impl Mlp { self.c_proj.forward(&x) } - fn load(vb: VarBuilder, _cfg: &Config, comm: Arc) -> Result { + fn load(vb: VarBuilder, _cfg: &Config, comm: Rc) -> Result { let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?; let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?; let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm.clone())?; @@ -421,7 +429,7 @@ impl Block { Ok(x) } - fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc) -> Result { + fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc) -> Result { let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?; let mlp = Mlp::load(vb.pp("mlp"), cfg, comm.clone())?; let input_layernorm = 
RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?; @@ -465,7 +473,7 @@ impl Llama { logits.to_dtype(DType::F32) } - pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc) -> Result { + pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc) -> Result { let wte = embedding(cfg, vb.pp("model.embed_tokens"))?; let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?; diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 6e206688f..136f8a4f7 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -11,7 +11,7 @@ pub fn wrap_err(err: ::candle::Error) -> PyErr { } #[derive(Clone)] -#[pyclass(name = "Tensor", unsendable)] +#[pyclass(name = "Tensor")] struct PyTensor(Tensor); impl std::ops::Deref for PyTensor { From 952eca6b540078b1f30b58d9eb930f8e32d903cb Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 27 Jul 2023 16:59:32 +0200 Subject: [PATCH 7/8] Fixing slice errors + comments. --- candle-core/src/error.rs | 7 +++++++ candle-nn/src/var_builder.rs | 25 ++++++++++++++++++++++--- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index f9e69122a..30d062391 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -79,6 +79,13 @@ pub enum Error { nth_shape: Shape, }, + #[error("Cannot divide tensor of shape {shape:?} equally along dim {dim} into {n_parts}")] + ShapeMismatchSplit { + shape: Shape, + dim: usize, + n_parts: usize, + }, + #[error("{op} can only be performed on a single dimension")] OnlySingleDimension { op: &'static str, dims: Vec }, diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 1466f6d01..3133f210d 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -135,6 +135,17 @@ impl<'a> VarBuilder<'a> { } impl<'a> VarBuilder<'a> { + /// Get part of a tensor, typically used to do Tensor Parallelism sharding. + /// + /// If the tensor is of size (1024, 1024). + /// + /// `dim` corresponds to the dimension to slice into + /// `rank` is the rank of the current process + /// `world_size` is the total number of ranks in the process group + /// + /// `get_sharded("tensor", 0, 0, 2)` means `tensor.i((..512))` + /// `get_sharded("tensor", 0, 1, 2)` means `tensor.i((512..))` + /// `get_sharded("tensor", 1, 0, 2)` means `tensor.i((.., ..512))` pub fn get_sharded( &self, tensor_name: &str, @@ -164,16 +175,24 @@ impl<'a> VarBuilder<'a> { let dtype = view.dtype(); let mut shape = view.shape().to_vec(); let size = shape[dim]; + + if size % world_size != 0 { + return Err(Error::ShapeMismatchSplit { + shape: shape.into(), + dim, + n_parts: world_size, + }); + } let block_size = size / world_size; let start = rank * block_size; let stop = (rank + 1) * block_size; let iterator = if dim == 0 { - view.slice(start..stop).unwrap() + view.slice(start..stop).map_err(|_| Error::Msg(format!("Cannot slice tensor {tensor_name} ({shape:?} along dim {dim} with {start}..{stop}")))? } else if dim == 1 { - view.slice((.., start..stop)).unwrap() + view.slice((.., start..stop)).map_err(|_| Error::Msg(format!("Cannot slice tensor {tensor_name} ({shape:?} along dim {dim} with {start}..{stop}")))? 
} else {
-                unimplemented!("Get sharded on dimensions != 0 or 1");
+                candle::bail!("Get sharded on dimensions != 0 or 1")
            };

            shape[dim] = block_size;

From 8435a99edd6f5aa7bc86d0ffdaf23e322a93f626 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 27 Jul 2023 20:11:57 +0200
Subject: [PATCH 8/8] Added comment about offsets.

---
 candle-nn/src/var_builder.rs | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index 3133f210d..5c222bf6f 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -187,6 +187,9 @@ impl<'a> VarBuilder<'a> {
             let start = rank * block_size;
             let stop = (rank + 1) * block_size;

+            // Everything here is expressed in tensor dimensions;
+            // byte offsets are handled automatically by safetensors.
+
             let iterator = if dim == 0 {
                 view.slice(start..stop).map_err(|_| Error::Msg(format!("Cannot slice tensor {tensor_name} ({shape:?} along dim {dim} with {start}..{stop}")))?
             } else if dim == 1 {
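                // e.g. (illustrative) for a (1024, 1024) weight with world_size = 2, rank = 1:
                //   dim == 0 -> view.slice(512..1024)       => shape (512, 1024)
                //   dim == 1 -> view.slice((.., 512..1024)) => shape (1024, 512)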