Fix the fast bf16 gemm cublas kernels. (huggingface#2274)
* Use flash-attn in gemma.

* Fix for the fast bf16 cublas gemm.

* Fix some clippy lints.

* Fix another lint.

* Proper clippy fix.
LaurentMazare authored and EricLBuehler committed Jun 29, 2024
1 parent 5b04d96 commit f7095bb
Showing 2 changed files with 7 additions and 6 deletions.
candle-core/examples/cuda_basics.rs (5 changes: 4 additions & 1 deletion)
@@ -9,8 +9,10 @@ use candle_core::{Device, Tensor};
 
 fn main() -> Result<()> {
     let device = Device::new_cuda(0)?;
-    let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?;
+    let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
+        .to_dtype(candle_core::DType::BF16)?;
     candle_core::cuda::set_gemm_reduced_precision_f32(false);
+    candle_core::cuda::set_gemm_reduced_precision_bf16(false);
     let _x1 = x.matmul(&x)?;
     drop(_x1);
     let start_time = std::time::Instant::now();
@@ -19,6 +21,7 @@ fn main() -> Result<()> {
     println!("fp32: {:?}", start_time.elapsed());
     drop(_x1);
     candle_core::cuda::set_gemm_reduced_precision_f32(true);
+    candle_core::cuda::set_gemm_reduced_precision_bf16(true);
     let _x1 = x.matmul(&x)?;
     drop(_x1);
     let start_time = std::time::Instant::now();
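With the bf16 conversion and the two toggles, the example now benchmarks the GEMM with reduced-precision accumulation switched off and then on. For reference, here is a minimal standalone sketch of the same pattern, assuming candle-core is built with the cuda feature; the time_matmul helper and the 4096 size are illustrative, while the candle calls mirror the diff above. Like the example, it issues no explicit device synchronization, so the timings are indicative rather than exact.

use candle_core::{DType, Device, Result, Tensor};

// Time one square matmul; a warm-up run keeps one-time kernel selection and
// setup out of the measurement.
fn time_matmul(x: &Tensor, label: &str) -> Result<()> {
    let warmup = x.matmul(x)?;
    drop(warmup);
    let start = std::time::Instant::now();
    let y = x.matmul(x)?;
    println!("{label}: {:?}", start.elapsed());
    drop(y);
    Ok(())
}

fn main() -> Result<()> {
    let device = Device::new_cuda(0)?;
    let x = Tensor::randn(0f32, 1.0, (4096, 4096), &device)?.to_dtype(DType::BF16)?;

    // Full f32 accumulation for bf16 matmuls.
    candle_core::cuda::set_gemm_reduced_precision_bf16(false);
    time_matmul(&x, "bf16, full precision")?;

    // The fast reduced-precision path that this commit fixes.
    candle_core::cuda::set_gemm_reduced_precision_bf16(true);
    time_matmul(&x, "bf16, reduced precision")?;
    Ok(())
}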
candle-core/src/cuda_backend/mod.rs (8 changes: 3 additions & 5 deletions)
@@ -2035,15 +2035,13 @@ unsafe fn gemm_strided_batched_bf16(
 
     let alpha_f32: f32 = cfg.gemm.alpha.to_f32();
     let beta_f32: f32 = cfg.gemm.beta.to_f32();
-    let alpha = f16::from_f32(alpha_f32);
-    let beta = f16::from_f32(beta_f32);
     // The type for alpha and beta depends on the computeType.
     // https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex
     let (compute_type, alpha, beta) = if gemm_reduced_precision_bf16() {
         (
-            sys::cublasComputeType_t::CUBLAS_COMPUTE_16F,
-            (&alpha) as *const f16 as *const _,
-            (&beta) as *const f16 as *const _,
+            sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16BF,
+            (&alpha_f32) as *const f32 as *const _,
+            (&beta_f32) as *const f32 as *const _,
         )
     } else {
         (
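This hunk is the actual fix. Per the cuBLAS documentation linked in the code comment, cublasGemmStridedBatchedEx derives the expected type of the alpha/beta scaling factors from computeType: CUBLAS_COMPUTE_16F takes f16 scalars, while the CUBLAS_COMPUTE_32F family, including the fast variants, takes f32. For bf16 operands the docs only list the 32F compute types as supported, so the reduced-precision path now selects CUBLAS_COMPUTE_32F_FAST_16BF (bf16-precision multiplies with f32 accumulation) and passes pointers to the f32 scalars instead of the old f16 ones. A sketch of the pairing rule, assuming the cudarc-generated sys bindings that candle's CUDA backend uses; the ScaleType helper is illustrative, not part of the patch.

use cudarc::cublas::sys; // candle's CUDA backend uses cudarc's cuBLAS bindings

// Illustrative helper (not part of the patch): the pointee type cuBLAS
// expects for alpha/beta is keyed off the compute type. The old code paired
// CUBLAS_COMPUTE_16F and f16 scalars with bf16 operands, a combination the
// docs do not support; the fix pairs a 32F compute type with f32 scalars.
enum ScaleType {
    F16, // CUBLAS_COMPUTE_16F: alpha/beta point to f16 values
    F32, // CUBLAS_COMPUTE_32F and its *_FAST_* variants: alpha/beta point to f32 values
}

fn scale_type(compute_type: sys::cublasComputeType_t) -> ScaleType {
    use sys::cublasComputeType_t as C;
    match compute_type {
        C::CUBLAS_COMPUTE_16F => ScaleType::F16,
        // The fast variants lower the precision of the internal multiplies
        // (f16, bf16, or tf32) but keep f32 accumulation and f32 scalars.
        C::CUBLAS_COMPUTE_32F
        | C::CUBLAS_COMPUTE_32F_FAST_16F
        | C::CUBLAS_COMPUTE_32F_FAST_16BF
        | C::CUBLAS_COMPUTE_32F_FAST_TF32 => ScaleType::F32,
        _ => unimplemented!("compute types not used by this kernel"),
    }
}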
