Skip to content

Commit

Permalink
Cuda graph experiments.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Oct 3, 2024
1 parent 6faecaa commit 9076dee
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 1 deletion.
4 changes: 4 additions & 0 deletions candle-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,7 @@ harness = false
[[example]]
name = "metal_basics"
required-features = ["metal"]

[[example]]
name = "cuda_basics"
required-features = ["cuda"]
57 changes: 56 additions & 1 deletion candle-core/examples/cuda_basics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,63 @@ extern crate intel_mkl_src;
use anyhow::Result;
use candle_core::{Device, Tensor};

fn cuda_graph() -> Result<()> {
let device = Device::new_cuda_with_stream(0)?;
let cu_device = match &device {
Device::Cuda(dev) => dev,
_ => unreachable!(),
};
let cu_stream = cu_device.cu_stream();
{
// load_ptx cannot be called while capturing the stream so we need this to happen
// beforehand.
let x = Tensor::zeros(16, candle_core::DType::F32, &device)?;
let y = Tensor::zeros(16, candle_core::DType::F32, &device)?;
y.slice_set(&x, 0, 0)?;
device.synchronize()?;
}
unsafe {
cudarc::driver::sys::lib()
.cuStreamBeginCapture_v2(
*cu_stream,
cudarc::driver::sys::CUstreamCaptureMode_enum::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL,
)
.result()?
};
{
let x = Tensor::zeros(16, candle_core::DType::F32, &device)?;
let y = Tensor::zeros(16, candle_core::DType::F32, &device)?;
y.slice_set(&x, 0, 0)?;
// let y = x.affine(2., 1.)?;
}
let cu_graph = unsafe {
let mut cu_graph = std::mem::MaybeUninit::uninit();
cudarc::driver::sys::lib()
.cuStreamEndCapture(*cu_stream, cu_graph.as_mut_ptr())
.result()?;
cu_graph.assume_init()
};
let cu_graph_e = unsafe {
let mut cu_graph_e = std::mem::MaybeUninit::uninit();
cudarc::driver::sys::lib()
.cuGraphInstantiateWithFlags(cu_graph_e.as_mut_ptr(), cu_graph, 0)
.result()?;
cu_graph_e.assume_init()
};
for _i in 0..100 {
unsafe {
cudarc::driver::sys::lib()
.cuGraphLaunch(cu_graph_e, *cu_stream)
.result()?
}
}
Ok(())
}

fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
cuda_graph()?;
return Ok(());
let device = Device::new_cuda_with_stream(0)?;
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
.to_dtype(candle_core::DType::BF16)?;
candle_core::cuda::set_gemm_reduced_precision_f32(false);
Expand Down

0 comments on commit 9076dee

Please sign in to comment.