TP sharding v2 #216

Merged (8 commits) on Jul 28, 2023
6 changes: 4 additions & 2 deletions Cargo.toml
@@ -19,8 +19,10 @@ byteorder = "1.4.3"
 clap = { version = "4.2.4", features = ["derive"] }
 # Re-enable this once 0.9.13 as been released as it would include the cublas-f16 changes
 # cudarc = { version = "0.9.13", optional = true, features = ["f16"] }
-cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas-bf16", features = ["f16"] }
-# TODO: Switch back to the official gemm implementation if we manage to upstream the changes.
+cudarc = { git = "https://github.com/coreylowman/cudarc.git", features = ["f16", "nccl"] }
+# TODO: Switch back to the official gemm implementation once the following are available.
+# https://github.com/sarah-ek/gemm/pull/8.
+# https://github.com/sarah-ek/gemm/pull/9.
 gemm = { git = "https://github.com/LaurentMazare/gemm.git" }
 hf-hub = "0.1.3"
 half = { version = "2.3.1", features = ["num-traits", "rand_distr"] }
7 changes: 7 additions & 0 deletions candle-core/src/error.rs
@@ -79,6 +79,13 @@ pub enum Error {
         nth_shape: Shape,
     },
 
+    #[error("Cannot divide tensor of shape {shape:?} equally along dim {dim} into {n_parts}")]
+    ShapeMismatchSplit {
+        shape: Shape,
+        dim: usize,
+        n_parts: usize,
+    },
+
     #[error("{op} can only be performed on a single dimension")]
     OnlySingleDimension { op: &'static str, dims: Vec<usize> },
 
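For context: a minimal sketch of how the new ShapeMismatchSplit variant might be raised when splitting a weight tensor for tensor-parallel sharding. The split_dim_size helper is hypothetical (not part of this PR); it only assumes candle-core's Shape::dims() accessor and the crate's Result alias.

// Hypothetical helper, not part of this PR: compute the per-shard size along
// `dim`, returning the new ShapeMismatchSplit error when the dimension does
// not divide evenly into `n_parts`.
fn split_dim_size(shape: &Shape, dim: usize, n_parts: usize) -> Result<usize> {
    let dim_size = shape.dims()[dim];
    if n_parts == 0 || dim_size % n_parts != 0 {
        return Err(Error::ShapeMismatchSplit {
            shape: shape.clone(),
            dim,
            n_parts,
        });
    }
    Ok(dim_size / n_parts)
}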
44 changes: 33 additions & 11 deletions candle-core/src/safetensors.rs
@@ -1,6 +1,6 @@
 use crate::{DType, Device, Error, Result, Tensor, WithDType};
 use safetensors::tensor as st;
-pub use safetensors::tensor::SafeTensors;
+use safetensors::tensor::SafeTensors;
 use std::borrow::Cow;
 
 impl From<DType> for st::Dtype {
@@ -63,15 +63,15 @@ impl Tensor {
     }
 }
 
-fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
-    let v = view.data();
+fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) -> Result<Tensor> {
     let size_in_bytes = T::DTYPE.size_in_bytes();
-    let elem_count = v.len() / size_in_bytes;
-    if (v.as_ptr() as usize) % size_in_bytes == 0 {
+    let elem_count = data.len() / size_in_bytes;
+    if (data.as_ptr() as usize) % size_in_bytes == 0 {
         // SAFETY This is safe because we just checked that this
         // was correctly aligned.
-        let data: &[T] = unsafe { std::slice::from_raw_parts(v.as_ptr() as *const T, elem_count) };
-        Tensor::from_slice(data, view.shape(), device)
+        let data: &[T] =
+            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
+        Tensor::from_slice(data, shape, device)
     } else {
         // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
         // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
@@ -81,13 +81,17 @@ fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<
         // We're downgrading the `c` pointer from T to u8, which removes alignment
         // constraints.
         unsafe {
-            std::ptr::copy_nonoverlapping(v.as_ptr(), c.as_mut_ptr() as *mut u8, v.len());
+            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
             c.set_len(elem_count)
         }
-        Tensor::from_slice(&c, view.shape(), device)
+        Tensor::from_slice(&c, shape, device)
     }
 }
 
+fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
+    convert_slice::<T>(view.data(), view.shape(), device)
+}
+
 fn convert_back_<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
     let size_in_bytes = T::DTYPE.size_in_bytes();
     let length = vs.len() * size_in_bytes;
@@ -112,7 +116,25 @@ impl<'a> Load for st::TensorView<'a> {
     }
 }
 
-pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
+impl Tensor {
+    pub fn from_raw_buffer(
+        data: &[u8],
+        dtype: DType,
+        shape: &[usize],
+        device: &Device,
+    ) -> Result<Self> {
+        match dtype {
+            DType::U8 => convert_slice::<u8>(data, shape, device),
+            DType::U32 => convert_slice::<u32>(data, shape, device),
+            DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
+            DType::F16 => convert_slice::<half::f16>(data, shape, device),
+            DType::F32 => convert_slice::<f32>(data, shape, device),
+            DType::F64 => convert_slice::<f64>(data, shape, device),
+        }
+    }
+}
+
+fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
     match view.dtype() {
         st::Dtype::U8 => convert_::<u8>(view, device),
         st::Dtype::U32 => convert_::<u8>(view, device),
@@ -124,7 +146,7 @@ pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
     }
 }
 
-pub fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
+fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
     // TODO: This makes an unnecessary copy when the tensor is on the cpu.
     let tensor = tensor.flatten_all()?;
     match tensor.dtype() {
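A hedged usage sketch for the new Tensor::from_raw_buffer API above, assuming the crate is imported as `candle` and that the bytes follow the platform's native little-endian f32 layout; the values and shape are illustrative only.

use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    // Four f32 values serialized to their little-endian byte representation.
    let raw: Vec<u8> = [1f32, 2.0, 3.0, 4.0]
        .iter()
        .flat_map(|v| v.to_le_bytes())
        .collect();
    // Reinterpret the byte buffer as a 2x2 f32 tensor on the CPU.
    let t = Tensor::from_raw_buffer(&raw, DType::F32, &[2, 2], &Device::Cpu)?;
    assert_eq!(t.dims(), &[2, 2]);
    Ok(())
}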
7 changes: 7 additions & 0 deletions candle-examples/Cargo.toml
@@ -19,6 +19,8 @@ serde = { workspace = true }
 serde_json = { workspace = true }
 num-traits = { workspace = true }
 intel-mkl-src = { workspace = true, optional = true }
+cudarc = { workspace = true, optional = true }
+half = { workspace = true, optional = true }
 
 [dev-dependencies]
 anyhow = { workspace = true }
@@ -40,3 +42,8 @@ default = []
 cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
 flash-attn = ["cuda", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
+nccl = ["dep:cudarc", "dep:half"]
+
+[[example]]
+name = "llama_multiprocess"
+required-features = ["cuda", "nccl"]
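The new nccl feature and the llama_multiprocess example wire in the cudarc NCCL bindings; the core tensor-parallel idea is that each rank keeps only its slice of every sharded weight. Below is a rough, illustrative sketch of that slicing, not code from this PR: it assumes candle's Tensor::narrow and Tensor::zeros, and in the real example rank and world_size would come from the NCCL communicator rather than being hard-coded.

use candle::{DType, Device, Result, Tensor};

// Keep only this rank's contiguous chunk of `weight` along `dim`.
// Mirrors the divisibility requirement behind the new ShapeMismatchSplit error.
fn shard(weight: &Tensor, dim: usize, rank: usize, world_size: usize) -> Result<Tensor> {
    let size = weight.dims()[dim];
    assert_eq!(size % world_size, 0, "dim {dim} must divide evenly across ranks");
    let chunk = size / world_size;
    weight.narrow(dim, rank * chunk, chunk)
}

fn main() -> Result<()> {
    // An 8x4 placeholder weight; rank 1 of 4 keeps rows 2..4 when sharding dim 0.
    let w = Tensor::zeros((8, 4), DType::F32, &Device::Cpu)?;
    let local = shard(&w, 0, 1, 4)?;
    assert_eq!(local.dims(), &[2, 4]);
    Ok(())
}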