diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index fa5c620a48..b8ba1cd1a7 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -36,6 +36,7 @@ serde_json = { workspace = true } symphonia = { version = "0.5.3", features = ["all"], optional = true } tokenizers = { workspace = true, features = ["onig"] } cpal= { version = "0.15.2", optional = true } +futures = "0.3.30" [dev-dependencies] anyhow = { workspace = true } @@ -73,6 +74,10 @@ depth_anything_v2 = ["palette", "enterpolation"] name = "llama_multiprocess" required-features = ["cuda", "nccl", "flash-attn"] +[[example]] +name = "llama_multinode" +required-features = ["cuda", "nccl"] + [[example]] name = "reinforcement-learning" required-features = ["pyo3"] diff --git a/candle-examples/examples/llama_multinode/README.md b/candle-examples/examples/llama_multinode/README.md new file mode 100644 index 0000000000..d580a97116 --- /dev/null +++ b/candle-examples/examples/llama_multinode/README.md @@ -0,0 +1,120 @@ +# Llama Multinode + +This project implements a distributed version of the Llama language model using Rust, CUDA, and NCCL for multi-node, multi-GPU inference. + +## TL;DR + +To quickly set up and run the project on the master node, use this single command: + +```bash +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y && \ +. "$HOME/.cargo/env" && \ +git clone -b chore/llama_multinode https://github.com/b0xtch/candle.git && \ +cd candle && \ +git submodule update --init --recursive && \ +pip install --upgrade huggingface_hub && \ +echo '' | huggingface-cli login && \ +source nccl_env_vars.sh && \ +RUST_BACKTRACE=1 cargo run --example llama_multinode --release --features="cuda nccl" -- \ + --num-nodes 2 \ + --node-rank 0 \ + --master-addr 10.0.10.30 \ + --master-port 29500 \ + --num-gpus-per-node 1 \ + --model-id "meta-llama/Meta-Llama-3-8B" \ + --dtype bf16 \ + --prompt "Once upon a time" +``` + +Note: Replace `10.0.10.30` with the private IP of your master node. + +## Prerequisites + +- CUDA-capable GPUs +- CUDA Toolkit 12.1 or later +- Rust toolchain +- Docker (for containerized setup) + +## Setup on AWS Nodes + +Follow these steps to set up and run the project on AWS nodes: + +1. Install Rust: + ```bash + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y + source "$HOME/.cargo/env" + ``` + +2. Clone the repository and update submodules: + ```bash + git clone -b chore/llama_multinode https://github.com/b0xtch/candle.git + cd candle + git submodule update --init --recursive + ``` + +3. Install and set up Hugging Face CLI: + ```bash + pip install --upgrade huggingface_hub + echo '' | huggingface-cli login + ``` + +4. Set up NCCL environment variables: + ```bash + source nccl_env_vars.sh + ``` + +## Running the Distributed Llama Model + +### On the Master Node + +Run the following command, replacing `10.0.10.30` with the private IP of your master node: + +```bash +RUST_BACKTRACE=1 cargo run --example llama_multinode --release --features="cuda nccl" -- \ + --num-nodes 2 \ + --node-rank 0 \ + --master-addr 10.0.10.30 \ + --master-port 29500 \ + --num-gpus-per-node 1 \ + --model-id "meta-llama/Meta-Llama-3-8B" \ + --dtype bf16 \ + --prompt "Once upon a time" +``` + +### On Worker Nodes + +Run the following command on each worker node, replacing `54.201.229.196` with the public IP of your master node: + +```bash +RUST_BACKTRACE=1 cargo run --example llama_multinode --release --features="cuda nccl" -- \ + --num-nodes 2 \ + --node-rank 1 \ + --master-addr 10.0.10.30 \ + --master-port 29500 \ + --num-gpus-per-node 1 \ + --model-id "meta-llama/Meta-Llama-3-8B" \ + --dtype bf16 \ + --prompt "Once upon a time" +``` + +Note: Increment the `--node-rank` for each additional worker node. + +## Troubleshooting + +- Ensure all nodes can communicate with each other over the specified port (29500 in this example). +- Check that the CUDA and NCCL versions are compatible across all nodes. +- Verify that the Hugging Face token has the necessary permissions to access the model. + +## Misc + +> Optional: Run the NVIDIA CUDA container: + ```bash + docker run -it --gpus all nvidia/cuda:12.1.1-devel-ubuntu20.04 /bin/bash + ``` + +> Optional: Inside the container, install necessary dependencies: + ```bash + apt-get -y update && \ + apt-get -y install curl git pkg-config libssl-dev + ``` +--- \ No newline at end of file diff --git a/candle-examples/examples/llama_multinode/main.rs b/candle-examples/examples/llama_multinode/main.rs new file mode 100644 index 0000000000..a85f83a876 --- /dev/null +++ b/candle-examples/examples/llama_multinode/main.rs @@ -0,0 +1,265 @@ +// An implementation of LLaMA https://github.com/facebookresearch/llama +// +// This is based on nanoGPT in a similar way to: +// https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py +// +// The tokenizer config can be retrieved from: +// https://huggingface.co/hf-internal-testing/llama-tokenizer/raw/main/tokenizer.json + +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +use anyhow::{bail, Error as E, Result}; +use clap::{Parser, ValueEnum}; + +use candle::{DType, Device, Tensor}; +use candle_transformers::generation::LogitsProcessor; +use cudarc::driver::CudaDevice; +use cudarc::nccl::safe::{Comm, Id}; +use futures::future::join_all; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use std::io::Write; +use std::net::{IpAddr, SocketAddr}; +use std::rc::Rc; +use tokenizers::Tokenizer; + +mod model; +use model::{Config, Llama}; + +mod nccl_id_distribution; +use nccl_id_distribution::{get_nccl_id_from_server, run_nccl_id_server}; + +const MAX_SEQ_LEN: usize = 4096; +const DEFAULT_PROMPT: &str = "My favorite theorem is "; + +#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] +enum Which { + V2_7b, + V2_70b, + V3_8b, + V3_70b, +} + +#[derive(Parser, Debug, Clone)] +#[command(author, version, about, long_about = None)] +struct Args { + #[arg(long)] + num_nodes: usize, + + #[arg(long)] + node_rank: usize, + + #[arg(long)] + master_addr: IpAddr, + + #[arg(long)] + master_port: u16, + + #[arg(long)] + num_gpus_per_node: usize, + + #[arg(long, default_value_t = 0.8)] + temperature: f64, + + #[arg(long)] + top_p: Option, + + #[arg(long, default_value_t = 299792458)] + seed: u64, + + #[arg(long, default_value_t = 100)] + sample_len: usize, + + #[arg(long)] + no_kv_cache: bool, + + #[arg(long)] + prompt: Option, + + #[arg(long)] + model_id: Option, + + #[arg(long)] + revision: Option, + + #[arg(long)] + dtype: Option, + + #[arg(long, default_value = "v3-8b")] + which: Which, +} + +#[tokio::main] +async fn main() -> Result<()> { + let args = Args::parse(); + + let dtype = match args.dtype.as_deref() { + Some("f16") => DType::F16, + Some("bf16") => DType::BF16, + Some("f32") => DType::F32, + Some(dtype) => bail!("Unsupported dtype {dtype}"), + None => match args.which { + Which::V2_7b | Which::V2_70b => DType::F16, + Which::V3_8b | Which::V3_70b => DType::BF16, + }, + }; + + let world_size = args.num_nodes * args.num_gpus_per_node; + let global_rank = args.node_rank * args.num_gpus_per_node; + let num_workers = args.num_nodes - 1; + + println!( + "Node rank: {}, Total nodes: {}", + args.node_rank, args.num_nodes + ); + + // Initialize NCCL + let unique_id = if args.node_rank == 0 { + println!("Initializing NCCL ID Server on master node"); + let id = Id::new().map_err(|e| anyhow::anyhow!("NCCL error: {:?}", e))?; + let id_clone = id.clone(); + tokio::spawn(async move { + if let Err(e) = run_nccl_id_server(args.master_port, id_clone, num_workers).await { + eprintln!("NCCL ID Server error: {:?}", e); + } + }); + + id + } else { + println!("Worker node connecting to NCCL ID Server"); + let server_addr = SocketAddr::new(args.master_addr, args.master_port); + get_nccl_id_from_server(server_addr).await? + }; + + println!("NCCL ID initialized, starting GPU processes"); + + let handles: Vec<_> = (0..args.num_gpus_per_node) + .map(|local_rank| { + let rank = global_rank + local_rank; + let args_clone = args.clone(); + let unique_id_clone = unique_id.clone(); + tokio::spawn(async move { + if let Err(e) = + run_gpu_process(args_clone, dtype, rank, world_size, unique_id_clone).await + { + eprintln!("GPU process error for rank {}: {:?}", rank, e); + } + }) + }) + .collect(); + + let results = join_all(handles).await; + + for result in results { + if let Err(e) = result { + eprintln!("Task join error: {:?}", e); + } + } + + Ok(()) +} + +async fn run_gpu_process( + args: Args, + dtype: DType, + rank: usize, + world_size: usize, + unique_id: Id, +) -> Result<()> { + let num_devices = CudaDevice::count()? as usize; + println!("Available CUDA devices: {}", num_devices); + + let local_rank = rank % num_devices; + println!("Using local rank {} for global rank {}", local_rank, rank); + + let device = CudaDevice::new(local_rank)?; + let comm = match Comm::from_rank(device, rank, world_size, unique_id) { + Ok(comm) => Rc::new(comm), + Err(err) => anyhow::bail!("nccl error {:?}", err.0), + }; + + println!("Rank {rank:?} spawned"); + + let api = Api::new()?; + let model_id = match args.model_id { + Some(model) => model, + None => match args.which { + Which::V2_7b => "meta-llama/Llama-2-7b-hf".to_string(), + Which::V2_70b => "meta-llama/Llama-2-70b-hf".to_string(), + Which::V3_8b => "meta-llama/Meta-Llama-3-8B".to_string(), + Which::V3_70b => "meta-llama/Meta-Llama-3-70B".to_string(), + }, + }; + println!("loading the model weights from {model_id}"); + let revision = args.revision.unwrap_or("main".to_string()); + let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision)); + let config_filename = api.get("config.json")?; + let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?; + let tokenizer_filename = api.get("tokenizer.json")?; + let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?; + + let device = Device::new_cuda(rank)?; + let cache = model::Cache::new(dtype, &config, &device)?; + + println!("building the model"); + let vb = unsafe { + candle_nn::var_builder::ShardedSafeTensors::var_builder(&filenames, dtype, &device)? + }; + let llama = Llama::load(vb, &cache, &config, comm)?; + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + + let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str()); + let mut tokens = tokenizer + .encode(prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer); + + println!("starting the inference loop"); + let temperature = if args.temperature <= 0. { + None + } else { + Some(args.temperature) + }; + let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p); + let mut new_tokens = vec![]; + let mut start_gen = std::time::Instant::now(); + let mut index_pos = 0; + for index in 0..args.sample_len { + // Only start timing at the second token as processing the first token waits for all the + // weights to be loaded in an async way. + if index == 1 { + start_gen = std::time::Instant::now() + }; + let context_size = if index > 0 { 1 } else { tokens.len() }; + let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; + let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; + let logits = llama.forward(&input, index_pos)?; + let logits = logits.squeeze(0)?; + index_pos += ctxt.len(); + + let next_token = logits_processor.sample(&logits)?; + tokens.push(next_token); + new_tokens.push(next_token); + if Some(next_token) == config.eos_token_id { + break; + } + if rank == 0 { + if let Some(t) = tokenizer.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + } + } + println!(); + if rank == 0 { + let dt = start_gen.elapsed(); + println!( + "\n\n{} tokens generated ({} token/s)\n", + args.sample_len, + (args.sample_len - 1) as f64 / dt.as_secs_f64(), + ); + } + Ok(()) +} diff --git a/candle-examples/examples/llama_multinode/model.rs b/candle-examples/examples/llama_multinode/model.rs new file mode 100644 index 0000000000..1a88d210aa --- /dev/null +++ b/candle-examples/examples/llama_multinode/model.rs @@ -0,0 +1,440 @@ +use candle::backend::BackendStorage; +use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D}; +use candle_nn::var_builder::ShardedVarBuilder as VarBuilder; +use candle_nn::{Embedding, Linear, Module, RmsNorm}; +use cudarc::nccl::safe::{Comm, ReduceOp}; +use std::collections::HashMap; +use std::rc::Rc; +use std::sync::{Arc, Mutex}; + +use super::MAX_SEQ_LEN; + +pub type Config = candle_transformers::models::llama::LlamaConfig; + +struct TensorParallelColumnLinear { + linear: Linear, +} + +impl TensorParallelColumnLinear { + fn new(linear: Linear) -> Self { + Self { linear } + } + fn forward(&self, x: &Tensor) -> Result { + self.linear.forward(x) + } +} + +struct TensorParallelRowLinear { + linear: Linear, + all_reduce: AllReduce, +} + +struct AllReduce { + comm: Rc, +} + +/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html +/// But for this example purposes, this will work +unsafe impl Sync for AllReduce {} +unsafe impl Send for AllReduce {} + +impl CustomOp1 for AllReduce { + fn name(&self) -> &'static str { + "allreduce" + } + + fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> { + candle::bail!("AllReduce is never used on cpu") + } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + s: &candle::CudaStorage, + l: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + use candle::cuda_backend::WrapErr; + use cudarc::driver::DeviceSlice; + use half::{bf16, f16}; + + let elem_count = l.shape().elem_count(); + let dev = s.device().clone(); + let dst = match s.dtype() { + DType::BF16 => { + let s = s.as_cuda_slice::()?; + let s = match l.contiguous_offsets() { + Some((0, l)) if l == s.len() => s, + Some(_) | None => candle::bail!("input has to be contiguous"), + }; + let mut dst = unsafe { dev.alloc::(elem_count) }.w()?; + self.comm + .all_reduce(s, &mut dst, &ReduceOp::Sum) + .map_err(candle::Error::debug)?; + candle::CudaStorage::wrap_cuda_slice(dst, dev) + } + DType::F16 => { + let s = s.as_cuda_slice::()?; + let s = match l.contiguous_offsets() { + Some((0, l)) if l == s.len() => s, + Some(_) | None => candle::bail!("input has to be contiguous"), + }; + let mut dst = unsafe { dev.alloc::(elem_count) }.w()?; + self.comm + .all_reduce(s, &mut dst, &ReduceOp::Sum) + .map_err(candle::Error::debug)?; + candle::CudaStorage::wrap_cuda_slice(dst, dev) + } + dtype => candle::bail!("unsupported dtype {dtype:?}"), + }; + Ok((dst, l.shape().clone())) + } +} + +impl TensorParallelRowLinear { + fn new(linear: Linear, comm: Rc) -> Self { + let all_reduce = AllReduce { comm }; + Self { linear, all_reduce } + } + fn forward(&self, x: &Tensor) -> Result { + self.linear.forward(x)?.apply_op1_no_bwd(&self.all_reduce) + } +} + +fn shard(dim: usize, rank: usize, world_size: usize) -> candle_nn::var_builder::Shard { + candle_nn::var_builder::Shard { + dim, + rank, + world_size, + } +} + +impl TensorParallelColumnLinear { + fn load(vb: VarBuilder, comm: Rc) -> Result { + let rank = comm.rank(); + let size = comm.world_size(); + let weight = vb.get_with_hints((), "weight", shard(0, rank, size))?; + Ok(Self::new(Linear::new(weight, None))) + } + + fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Rc) -> Result { + let rank = comm.rank(); + let size = comm.world_size(); + let weights: Vec<_> = prefixes + .iter() + .map(|p| vb.pp(p).get_with_hints((), "weight", shard(0, rank, size))) + .collect::>>()?; + let weight = Tensor::cat(&weights, 0)?; + Ok(Self::new(Linear::new(weight, None))) + } +} + +impl TensorParallelRowLinear { + fn load(vb: VarBuilder, comm: Rc) -> Result { + let rank = comm.rank(); + let size = comm.world_size(); + let weight = vb.get_with_hints((), "weight", shard(1, rank, size))?; + Ok(Self::new(Linear::new(weight, None), comm)) + } +} + +#[derive(Clone)] +pub struct Cache { + masks: Arc>>, + #[allow(clippy::type_complexity)] + kvs: Arc>>>, + cos: Tensor, + sin: Tensor, + device: Device, +} + +impl Cache { + pub fn new(dtype: DType, config: &Config, device: &Device) -> Result { + // precompute freqs_cis + let n_elem = config.hidden_size / config.num_attention_heads; + let theta: Vec<_> = (0..n_elem) + .step_by(2) + .map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32)) + .collect(); + let theta = Tensor::new(theta.as_slice(), device)?; + let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)? + .to_dtype(DType::F32)? + .reshape((MAX_SEQ_LEN, 1))? + .matmul(&theta.reshape((1, theta.elem_count()))?)?; + // This is different from the paper, see: + // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112 + let cos = idx_theta.cos()?.to_dtype(dtype)?; + let sin = idx_theta.sin()?.to_dtype(dtype)?; + Ok(Self { + masks: Arc::new(Mutex::new(HashMap::new())), + kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])), + cos, + sin, + device: device.clone(), + }) + } + + fn mask(&self, t: usize) -> Result { + let mut masks = self.masks.lock().unwrap(); + if let Some(mask) = masks.get(&t) { + Ok(mask.clone()) + } else { + // TODO: If we support bool or u8 tensors, this would be better. + let mask: Vec<_> = (0..t) + .flat_map(|i| (0..t).map(move |j| u32::from(j > i))) + .collect(); + let mask = Tensor::from_slice(&mask, (t, t), &self.device)?; + masks.insert(t, mask.clone()); + Ok(mask) + } + } +} + +fn silu(xs: &Tensor) -> Result { + xs / (xs.neg()?.exp()? + 1.0)? +} + +fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result { + let weight = vb.get((size2, size1), "weight")?; + Ok(Linear::new(weight, None)) +} + +fn embedding(cfg: &Config, vb: VarBuilder) -> Result { + let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?; + Ok(Embedding::new(embeddings, cfg.hidden_size)) +} + +struct CausalSelfAttention { + qkv_proj: TensorParallelColumnLinear, + o_proj: TensorParallelRowLinear, + num_attention_heads: usize, + num_key_value_heads: usize, + head_dim: usize, + cache: Cache, +} + +impl CausalSelfAttention { + fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result { + let (_b_sz, _, seq_len, _hidden_size) = x.shape().dims4()?; + let cos = self.cache.cos.narrow(0, index_pos, seq_len)?; + let sin = self.cache.sin.narrow(0, index_pos, seq_len)?; + candle_nn::rotary_emb::rope(x, &cos, &sin) + } + + fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result { + let (b_sz, seq_len, _) = x.shape().dims3()?; + let x_dtype = x.dtype(); + + let qkv = self.qkv_proj.forward(x)?; + let hidden_size = self.num_attention_heads * self.head_dim; + + let q = qkv.i((.., .., ..self.num_attention_heads * self.head_dim))?; + let k = qkv.i(( + .., + .., + self.num_attention_heads * self.head_dim + ..self.num_attention_heads * self.head_dim + + self.num_key_value_heads * self.head_dim, + ))?; + let v = qkv.i(( + .., + .., + self.num_attention_heads * self.head_dim + self.num_key_value_heads * self.head_dim.., + ))?; + // todo!("Q {:?} K {:?} V {:?} - x {:?}", q.shape(), k.shape(), v.shape(), x.shape()); + + let q = q + .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous()?; + let k = k + .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous()?; + let mut v = v + .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous()?; + + let q = self.apply_rotary_emb(&q, index_pos)?; + let mut k = self.apply_rotary_emb(&k, index_pos)?; + + let mut cache = self.cache.kvs.lock().unwrap(); + if let Some((cache_k, cache_v)) = &cache[block_idx] { + k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?; + v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?; + let k_seq_len = k.dims()[1]; + if k_seq_len > MAX_SEQ_LEN { + k = k + .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)? + .contiguous()? + } + let v_seq_len = v.dims()[1]; + if v_seq_len > 2 * MAX_SEQ_LEN { + v = v + .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)? + .contiguous()? + } + } + cache[block_idx] = Some((k.clone(), v.clone())); + + let k = self.repeat_kv(k)?; + let v = self.repeat_kv(v)?; + let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; + let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?; + let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; + let att = candle_nn::ops::softmax(&att, D::Minus1)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + let y = att.matmul(&v.contiguous()?)?; + let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?; + let y = y.to_dtype(x_dtype)?; + let y = self.o_proj.forward(&y)?; + + Ok(y) + } + + fn repeat_kv(&self, x: Tensor) -> Result { + let n_rep = self.num_attention_heads / self.num_key_value_heads; + candle_transformers::utils::repeat_kv(x, n_rep) + } + + fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc) -> Result { + let qkv_proj = TensorParallelColumnLinear::load_multi( + vb.clone(), + &["q_proj", "k_proj", "v_proj"], + comm.clone(), + )?; + let o_proj = TensorParallelRowLinear::load(vb.pp("o_proj"), comm.clone())?; + Ok(Self { + qkv_proj, + o_proj, + num_attention_heads: cfg.num_attention_heads / comm.world_size(), + num_key_value_heads: cfg.num_key_value_heads() / comm.world_size(), + head_dim: cfg.hidden_size / cfg.num_attention_heads, + cache: cache.clone(), + }) + } +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; + Ok(m) +} + +struct Mlp { + c_fc1: TensorParallelColumnLinear, + c_fc2: TensorParallelColumnLinear, + c_proj: TensorParallelRowLinear, +} + +impl Mlp { + fn forward(&self, x: &Tensor) -> Result { + let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?; + self.c_proj.forward(&x) + } + + fn load(vb: VarBuilder, _cfg: &Config, comm: Rc) -> Result { + let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?; + let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?; + let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm)?; + Ok(Self { + c_fc1, + c_fc2, + c_proj, + }) + } +} + +struct Block { + rms_1: RmsNorm, + attn: CausalSelfAttention, + rms_2: RmsNorm, + mlp: Mlp, +} + +fn rms_norm(size: usize, eps: f64, vb: VarBuilder) -> Result { + let weight = vb.get_with_hints(size, "weight", shard(0, 0, 1))?; + Ok(RmsNorm::new(weight, eps)) +} + +impl Block { + fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self { + Self { + rms_1, + attn, + rms_2, + mlp, + } + } + + fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result { + let residual = x; + let x = self.rms_1.forward(x)?; + let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?; + let residual = &x; + let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?; + Ok(x) + } + + fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc) -> Result { + let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?; + let mlp = Mlp::load(vb.pp("mlp"), cfg, comm)?; + let input_layernorm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("input_layernorm"))?; + let post_attention_layernorm = + rms_norm(cfg.hidden_size, 1e-5, vb.pp("post_attention_layernorm"))?; + Ok(Self::new( + input_layernorm, + attn, + post_attention_layernorm, + mlp, + )) + } +} + +pub struct Llama { + wte: Embedding, + blocks: Vec, + ln_f: RmsNorm, + lm_head: Linear, +} + +impl Llama { + fn new(wte: Embedding, blocks: Vec, ln_f: RmsNorm, lm_head: Linear) -> Self { + Self { + wte, + blocks, + ln_f, + lm_head, + } + } + + pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result { + let (_b_sz, seq_len) = x.shape().dims2()?; + let mut x = self.wte.forward(x)?; + for (block_idx, block) in self.blocks.iter().enumerate() { + x = block.forward(&x, index_pos, block_idx)?; + } + let x = self.ln_f.forward(&x)?; + let x = x.i((.., seq_len - 1, ..))?; + let logits = self.lm_head.forward(&x)?; + logits.to_dtype(DType::F32) + } + + pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc) -> Result { + let wte = embedding(cfg, vb.pp("model.embed_tokens"))?; + let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + let norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("model.norm"))?; + let blocks: Vec<_> = (0..cfg.num_hidden_layers) + .map(|i| { + Block::load( + vb.pp(&format!("model.layers.{i}")), + cache, + cfg, + comm.clone(), + ) + }) + .collect::>>()?; + Ok(Self::new(wte, blocks, norm, lm_head)) + } +} diff --git a/candle-examples/examples/llama_multinode/nccl_env_vars.sh b/candle-examples/examples/llama_multinode/nccl_env_vars.sh new file mode 100644 index 0000000000..1d83846da4 --- /dev/null +++ b/candle-examples/examples/llama_multinode/nccl_env_vars.sh @@ -0,0 +1,8 @@ +export NCCL_DEBUG=INFO +export NCCL_DEBUG_SUBSYS=ALL +export NCCL_SOCKET_IFNAME=ens5 +export NCCL_NET=Socket +export NCCL_NET_OFI_DISABLE=1 +export NCCL_IB_DISABLE=1 +export NCCL_SOCKET_NTHREADS=4 +export NCCL_NSOCKS_PERTHREAD=4 \ No newline at end of file diff --git a/candle-examples/examples/llama_multinode/nccl_id_distribution.rs b/candle-examples/examples/llama_multinode/nccl_id_distribution.rs new file mode 100644 index 0000000000..4d0a1f9c0d --- /dev/null +++ b/candle-examples/examples/llama_multinode/nccl_id_distribution.rs @@ -0,0 +1,54 @@ +use cudarc::nccl::safe::Id; +use std::convert::TryInto; +use std::io::{Read, Write}; +use std::net::{SocketAddr, TcpListener, TcpStream}; +use std::sync::Arc; + +pub async fn run_nccl_id_server(port: u16, nccl_id: Id, num_workers: usize) -> std::io::Result<()> { + let listener = TcpListener::bind(("0.0.0.0", port))?; + println!("NCCL ID Server listening on 0.0.0.0:{}", port); + + let nccl_id_bytes: &[i8; 128] = nccl_id.internal(); + let nccl_id_bytes = Arc::new(*nccl_id_bytes); + + let mut connected_workers = 0; + while connected_workers < num_workers { + match listener.accept() { + Ok((mut stream, addr)) => { + println!("Worker connected from: {}", addr); + let nccl_id_bytes = Arc::clone(&nccl_id_bytes); + + let bytes_to_send: Vec = nccl_id_bytes.iter().map(|&x| x as u8).collect(); + if let Err(e) = stream.write_all(&bytes_to_send) { + eprintln!("Error sending NCCL ID to worker {}: {:?}", addr, e); + } else { + connected_workers += 1; + println!( + "NCCL ID sent to worker {}. {}/{} workers connected.", + addr, connected_workers, num_workers + ); + } + } + Err(e) => { + eprintln!("Error accepting connection: {:?}", e); + } + } + } + + println!("NCCL ID sent to all {} workers", num_workers); + Ok(()) +} + +pub async fn get_nccl_id_from_server(addr: SocketAddr) -> std::io::Result { + let mut stream = TcpStream::connect(addr)?; + let mut buffer = [0u8; 128]; + stream.read_exact(&mut buffer)?; + + let internal: [i8; 128] = buffer + .iter() + .map(|&b| b as i8) + .collect::>() + .try_into() + .unwrap(); + Ok(Id::uninit(internal)) +}