Skip to content

Commit b2be9ca

Browse files
committed
Return error instead of panicking in implicit GEMM
1 parent 6c07340 commit b2be9ca

File tree

2 files changed

+58
-23
lines changed

2 files changed

+58
-23
lines changed

crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs

+39-22
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use cmma::{Matrix, MatrixIdent, MatrixLayout};
66
use cubecl::{
77
cube,
88
ir::{Elem, FloatKind},
9+
linalg::matmul::kernels::{MatmulAvailabilityError, MatmulLaunchError},
910
prelude::*,
1011
Compiler, CubeCount, CubeDim, Feature,
1112
};
@@ -66,7 +67,7 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement>(
6667

6768
let padded_batch_size = padded_batch_size(batch_size, out_h, out_w);
6869

69-
if !can_do_implicit_gemm::<R, F>(
70+
check_availability::<R, F>(
7071
batch_size,
7172
in_channels,
7273
out_channels,
@@ -75,15 +76,7 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement>(
7576
out_h,
7677
out_w,
7778
&input.client,
78-
) {
79-
panic!(
80-
"Requirements for implicit GEMM not met:
81-
- CMMA must be available
82-
- `groups` must be 1
83-
- subcube size must be non-variable (might not hold on Intel)
84-
"
85-
);
86-
}
79+
)?;
8780

8881
// If input is contiguous NCHW, use custom transpose kernel
8982
let input = match input.is_contiguous() {
@@ -643,7 +636,7 @@ fn load_weight_tile<F: Float, FMat: Float>(
643636
}
644637

645638
#[allow(clippy::too_many_arguments)]
646-
pub(crate) fn can_do_implicit_gemm<R: JitRuntime, E: FloatElement>(
639+
pub(crate) fn check_availability<R: JitRuntime, E: FloatElement>(
647640
batch_size: usize,
648641
in_channels: usize,
649642
out_channels: usize,
@@ -652,7 +645,7 @@ pub(crate) fn can_do_implicit_gemm<R: JitRuntime, E: FloatElement>(
652645
out_h: usize,
653646
out_w: usize,
654647
client: &ComputeClient<R::Server, R::Channel>,
655-
) -> bool {
648+
) -> Result<(), ConvLaunchError> {
656649
let cmma_k = match (
657650
E::as_elem_native_unchecked(),
658651
client
@@ -672,19 +665,43 @@ pub(crate) fn can_do_implicit_gemm<R: JitRuntime, E: FloatElement>(
672665
let gemm_n = out_channels;
673666
let gemm_k = in_channels * kernel_h * kernel_w;
674667

675-
let size = find_cmma_size::<R, E>(client, gemm_m as u32, gemm_k as u32, gemm_n as u32);
676-
677-
if let Some((cmma_m, cmma_k, cmma_n)) = size {
678-
let warps_per_cube = 8;
668+
let (cmma_m, cmma_n, cmma_k) =
669+
find_cmma_size::<R, E>(client, gemm_m as u32, gemm_k as u32, gemm_n as u32).ok_or_else(
670+
|| {
671+
ConvLaunchError::Matmul(MatmulLaunchError::Unavailable(
672+
MatmulAvailabilityError::CmmaInstructionUnavailable {
673+
input: E::as_elem_native_unchecked(),
674+
output: E::as_elem_native_unchecked(),
675+
m: 16,
676+
n: 16,
677+
k: cmma_k as u32,
678+
},
679+
))
680+
},
681+
)?;
682+
683+
let warps_per_cube = 8;
684+
685+
let smem_size = ((cmma_m + cmma_n) * cmma_k * warps_per_cube) as usize * size_of::<f16>();
686+
if <R::Compiler as Compiler>::max_shared_memory_size() < smem_size {
687+
return Err(ConvLaunchError::Matmul(MatmulLaunchError::InvalidConfig(
688+
Box::new("Not enough shared memory"),
689+
)));
690+
}
679691

680-
let smem_size = ((cmma_m + cmma_n) * cmma_k * warps_per_cube) as usize * size_of::<f16>();
681-
let topology = client.properties().hardware_properties();
682-
let not_intel = topology.plane_size_min >= 32;
692+
let topology = client.properties().hardware_properties();
693+
if topology.plane_size_min < 32 {
694+
return Err(ConvLaunchError::Matmul(MatmulLaunchError::Unavailable(
695+
MatmulAvailabilityError::PlaneDimUnsupported {
696+
plane_dim: topology.plane_size_min,
697+
},
698+
)));
699+
}
683700

684-
<R::Compiler as Compiler>::max_shared_memory_size() >= smem_size && groups == 1 && not_intel
685-
} else {
686-
false
701+
if groups != 1 {
702+
return Err(ConvLaunchError::Groups(groups));
687703
}
704+
Ok(())
688705
}
689706

690707
fn padded_k(

crates/burn-jit/src/kernel/conv/error.rs

+19-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,29 @@
1+
use core::fmt::Debug;
12
use cubecl::{linalg::matmul::kernels::MatmulLaunchError, tune::AutotuneError};
23

3-
#[derive(Debug)]
44
pub enum ConvLaunchError {
55
Matmul(MatmulLaunchError),
6+
Groups(usize),
67
Unknown,
78
}
89

10+
impl Debug for ConvLaunchError {
11+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
12+
match self {
13+
ConvLaunchError::Matmul(err) => {
14+
write!(f, "{err:?}")
15+
}
16+
ConvLaunchError::Groups(groups) => {
17+
writeln!(
18+
f,
19+
"Unable to launch convolution because groups must be 1, but it is actually {groups}",
20+
)
21+
}
22+
ConvLaunchError::Unknown => write!(f, "Unknown"),
23+
}
24+
}
25+
}
26+
927
impl From<MatmulLaunchError> for ConvLaunchError {
1028
fn from(value: MatmulLaunchError) -> Self {
1129
Self::Matmul(value)

0 commit comments

Comments
 (0)