
TP sharding v2 #216

Merged: 8 commits into main from llama_multiprocess2, Jul 28, 2023
Conversation

@Narsil (Collaborator) commented on Jul 21, 2023

Second, updated version of TP sharding.

  • This WAS NOT tested for accuracy (because I hadn't implemented all_reduce yet).
  • Currently waiting for a simple API to get the CUDA storage so I can call all_reduce directly and recreate a tensor from the result (that way I don't have to make everything pub).

I tried to keep the modifications to var_builder minimal for now, since that is not the purpose of this PR.

Comment on lines +30 to +74
// Assumed imports from the surrounding file:
// use std::rc::Rc;
// use candle::{CpuStorage, CustomOp1, Layout, Result, Shape, Tensor};
// use cudarc::nccl::safe::{Comm, ReduceOp};
// use half::f16;

/// Wraps an NCCL communicator so the all-reduce can be exposed as a candle custom op.
struct AllReduce {
    comm: Rc<Comm>,
}

/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
/// But for the purposes of this example, this will work.
unsafe impl Sync for AllReduce {}
unsafe impl Send for AllReduce {}

impl CustomOp1 for AllReduce {
    fn name(&self) -> &'static str {
        "allreduce"
    }

    fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
        todo!("an all_reduce implementation on cpu is not necessary for a single node");
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s: &candle::CudaStorage,
        l: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::WrapErr;
        let elem_count = l.shape().elem_count();
        let dev = s.device().clone();
        let s = s.as_cuda_slice::<f16>()?;
        // TODO: handle non-contiguous inputs, e.g.:
        // let s = match l.contiguous_offsets() {
        //     None => Err(Error::Wrapped("input has to be contiguous".into()))?,
        //     Some((o1, o2)) => s.slice(o1..o2),
        // };
        // Allocate the destination buffer on the same device, then run the
        // NCCL sum reduction across all ranks and wrap the result back up
        // as a candle storage.
        let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
        self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
        let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
        Ok((dst, l.shape().clone()))
    }
}

fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
    x.custom_op1(AllReduce { comm: comm.clone() })
}
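
To show how this op slots into a model, here is a minimal sketch of a row-parallel linear layer. ParallelLinear and its fields are assumptions made for illustration; only all_reduce_sum comes from the diff above.

// Illustrative sketch only, not part of this PR: a row-parallel linear
// layer whose weight is sharded along the input dimension. Each rank
// computes a partial matmul over its shard, then the partial outputs
// are summed across ranks.
struct ParallelLinear {
    weight: Tensor, // local shard of shape (out_dim, in_dim / world_size)
    comm: Rc<Comm>,
}

impl ParallelLinear {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // `x` holds this rank's slice of the input features.
        let partial = x.matmul(&self.weight.t()?)?;
        // Sum the partial outputs from all ranks to get the full result.
        all_reduce_sum(&partial, &self.comm)
    }
}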

@Narsil (Collaborator, Author):

This is the core of the new change:

  • TP row/column sharding (a sketch of what this means follows below)
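
For readers unfamiliar with tensor parallelism: in column-parallel sharding the weight is split along the output dimension, so each rank produces a slice of the output with no communication needed; in row-parallel sharding it is split along the input dimension, so each rank produces a partial sum that must be combined with all_reduce_sum. The shard helper below is an illustrative assumption, not the exact var_builder API added in this PR; narrow and dim are real candle ops.

// Illustrative helper: select this rank's shard of a tensor along `dim`.
// Assumes the size of `dim` is divisible by `world_size`.
fn shard(tensor: &Tensor, dim: usize, rank: usize, world_size: usize) -> Result<Tensor> {
    let size = tensor.dim(dim)?;
    let block = size / world_size;
    tensor.narrow(dim, rank * block, block)
}

// Column-parallel: shard(&weight, 0, rank, n) splits the output dim.
// Row-parallel:    shard(&weight, 1, rank, n) splits the input dim.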

Review threads on candle-nn/src/var_builder.rs: 4 total, all resolved (2 outdated).
@Narsil merged commit 4f260ef into main on Jul 28, 2023. 10 checks passed.
@LaurentMazare deleted the llama_multiprocess2 branch on August 15, 2023 at 20:10.