Skip to content

Commit

Permalink
More cuda graph attempts.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Oct 3, 2024
1 parent 9076dee commit b295685
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions candle-core/examples/cuda_basics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@ fn cuda_graph() -> Result<()> {
{
// load_ptx cannot be called while the stream is being captured, so the ops are run
// once here, before the capture starts, to get their kernels loaded beforehand.
let x = Tensor::zeros(16, candle_core::DType::F32, &device)?;
let y = Tensor::zeros(16, candle_core::DType::F32, &device)?;
y.slice_set(&x, 0, 0)?;
let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
.to_dtype(candle_core::DType::BF16)?;
let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
.to_dtype(candle_core::DType::BF16)?;
let v = Tensor::zeros(4096, candle_core::DType::F32, &device)?
.to_dtype(candle_core::DType::BF16)?;
let _x = x.mul(&u)?.broadcast_add(&v)?;
device.synchronize()?;
}
unsafe {
Expand All @@ -31,10 +35,15 @@ fn cuda_graph() -> Result<()> {
.result()?
};
{
let x = Tensor::zeros(16, candle_core::DType::F32, &device)?;
let y = Tensor::zeros(16, candle_core::DType::F32, &device)?;
y.slice_set(&x, 0, 0)?;
// let y = x.affine(2., 1.)?;
let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
.to_dtype(candle_core::DType::BF16)?;
let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
.to_dtype(candle_core::DType::BF16)?;
let v = Tensor::zeros(4096, candle_core::DType::F32, &device)?
.to_dtype(candle_core::DType::BF16)?;
for _i in 0..1 {
x = x.mul(&u)?.broadcast_add(&v)?;
}
}
let cu_graph = unsafe {
let mut cu_graph = std::mem::MaybeUninit::uninit();
Expand Down

0 comments on commit b295685

Please sign in to comment.