From b08dbfede911833c7ed8cb5e88e86639f265496a Mon Sep 17 00:00:00 2001 From: laurent Date: Wed, 2 Oct 2024 23:09:18 +0200 Subject: [PATCH] Fix for cudnn bf16 conv2d. --- candle-core/src/cuda_backend/cudnn.rs | 11 ++++++----- candle-core/src/cuda_backend/mod.rs | 13 ++++++++----- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/candle-core/src/cuda_backend/cudnn.rs b/candle-core/src/cuda_backend/cudnn.rs index d604863d3..f5b4db902 100644 --- a/candle-core/src/cuda_backend/cudnn.rs +++ b/candle-core/src/cuda_backend/cudnn.rs @@ -26,6 +26,7 @@ impl From for crate::Error { pub(crate) fn launch_conv2d< T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType, + Y: cudarc::cudnn::CudnnDataType, >( src: &CudaView<T>, src_l: &crate::Layout, @@ -48,7 +49,7 @@ pub(crate) fn launch_conv2d< } c })?; - let conv = cudnn.create_conv2d::<T>( + let conv = cudnn.create_conv2d::<Y>( /* pad */ [params.padding as i32, params.padding as i32], /* stride */ [params.stride as i32, params.stride as i32], /* dilation */ [params.dilation as i32, params.dilation as i32], @@ -62,18 +63,18 @@ pub(crate) fn launch_conv2d< ]; // Note that `src` already starts at the proper offset. let x = if src_l.is_contiguous() { - cudnn.create_4d_tensor( + cudnn.create_4d_tensor::<T>( cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, x_shape, )? } else { let s = src_l.stride(); - cudnn.create_4d_tensor_ex( + cudnn.create_4d_tensor_ex::<T>( x_shape, [s[0] as i32, s[1] as i32, s[2] as i32, s[3] as i32], )? 
}; - let w = cudnn.create_4d_filter( + let w = cudnn.create_4d_filter::<T>( cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, [ params.c_out as i32, @@ -83,7 +84,7 @@ pub(crate) fn launch_conv2d< ], )?; let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32); - let y = cudnn.create_4d_tensor( + let y = cudnn.create_4d_tensor::<T>( cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, [params.b_size as i32, params.c_out as i32, h_out, w_out], )?; diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 07bb1785d..f14e00d53 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -1522,7 +1522,7 @@ impl BackendStorage for CudaStorage { let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(kernel_l.start_offset()..); let mut out = unsafe { device.alloc::<u8>(dst_el) }.w()?; - crate::cudnn::launch_conv2d::<u8>(inp, inp_l, k, &mut out, params, &device) + crate::cudnn::launch_conv2d::<u8, u8>(inp, inp_l, k, &mut out, params, &device) .map_err(crate::Error::wrap)?; S::U8(out) } @@ -1530,7 +1530,10 @@ impl BackendStorage for CudaStorage { let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(kernel_l.start_offset()..); let mut out = unsafe { device.alloc::<bf16>(dst_el) }.w()?; - crate::cudnn::launch_conv2d::<bf16>(inp, inp_l, k, &mut out, params, &device) + // Only PSEUDO_BFLOAT16_CONFIG is supported in cudnn, there is no "true bfloat16" + // version. 
+ // https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#id88 + crate::cudnn::launch_conv2d::<bf16, f32>(inp, inp_l, k, &mut out, params, &device) .map_err(crate::Error::wrap)?; S::BF16(out) } @@ -1538,7 +1541,7 @@ impl BackendStorage for CudaStorage { let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(kernel_l.start_offset()..); let mut out = unsafe { device.alloc::<f16>(dst_el) }.w()?; - crate::cudnn::launch_conv2d::<f16>(inp, inp_l, k, &mut out, params, &device) + crate::cudnn::launch_conv2d::<f16, f16>(inp, inp_l, k, &mut out, params, &device) .map_err(crate::Error::wrap)?; S::F16(out) } @@ -1546,7 +1549,7 @@ impl BackendStorage for CudaStorage { let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(kernel_l.start_offset()..); let mut out = unsafe { device.alloc::<f32>(dst_el) }.w()?; - crate::cudnn::launch_conv2d::<f32>(inp, inp_l, k, &mut out, params, &device) + crate::cudnn::launch_conv2d::<f32, f32>(inp, inp_l, k, &mut out, params, &device) .map_err(crate::Error::wrap)?; S::F32(out) } @@ -1554,7 +1557,7 @@ impl BackendStorage for CudaStorage { let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(kernel_l.start_offset()..); let mut out = unsafe { device.alloc::<f64>(dst_el) }.w()?; - crate::cudnn::launch_conv2d::<f64>(inp, inp_l, k, &mut out, params, &device) + crate::cudnn::launch_conv2d::<f64, f64>(inp, inp_l, k, &mut out, params, &device) .map_err(crate::Error::wrap)?; S::F64(out) }