@@ -6,6 +6,7 @@ use cmma::{Matrix, MatrixIdent, MatrixLayout};
6
6
use cubecl:: {
7
7
cube,
8
8
ir:: { Elem , FloatKind } ,
9
+ linalg:: matmul:: kernels:: { MatmulAvailabilityError , MatmulLaunchError } ,
9
10
prelude:: * ,
10
11
Compiler , CubeCount , CubeDim , Feature ,
11
12
} ;
@@ -66,7 +67,7 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement>(
66
67
67
68
let padded_batch_size = padded_batch_size ( batch_size, out_h, out_w) ;
68
69
69
- if ! can_do_implicit_gemm :: < R , F > (
70
+ check_availability :: < R , F > (
70
71
batch_size,
71
72
in_channels,
72
73
out_channels,
@@ -75,15 +76,7 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement>(
75
76
out_h,
76
77
out_w,
77
78
& input. client ,
78
- ) {
79
- panic ! (
80
- "Requirements for implicit GEMM not met:
81
- - CMMA must be available
82
- - `groups` must be 1
83
- - subcube size must be non-variable (might not hold on Intel)
84
- "
85
- ) ;
86
- }
79
+ ) ?;
87
80
88
81
// If input is contiguous NCHW, use custom transpose kernel
89
82
let input = match input. is_contiguous ( ) {
@@ -643,7 +636,7 @@ fn load_weight_tile<F: Float, FMat: Float>(
643
636
}
644
637
645
638
#[ allow( clippy:: too_many_arguments) ]
646
- pub ( crate ) fn can_do_implicit_gemm < R : JitRuntime , E : FloatElement > (
639
+ pub ( crate ) fn check_availability < R : JitRuntime , E : FloatElement > (
647
640
batch_size : usize ,
648
641
in_channels : usize ,
649
642
out_channels : usize ,
@@ -652,7 +645,7 @@ pub(crate) fn can_do_implicit_gemm<R: JitRuntime, E: FloatElement>(
652
645
out_h : usize ,
653
646
out_w : usize ,
654
647
client : & ComputeClient < R :: Server , R :: Channel > ,
655
- ) -> bool {
648
+ ) -> Result < ( ) , ConvLaunchError > {
656
649
let cmma_k = match (
657
650
E :: as_elem_native_unchecked ( ) ,
658
651
client
@@ -672,19 +665,43 @@ pub(crate) fn can_do_implicit_gemm<R: JitRuntime, E: FloatElement>(
672
665
let gemm_n = out_channels;
673
666
let gemm_k = in_channels * kernel_h * kernel_w;
674
667
675
- let size = find_cmma_size :: < R , E > ( client, gemm_m as u32 , gemm_k as u32 , gemm_n as u32 ) ;
676
-
677
- if let Some ( ( cmma_m, cmma_k, cmma_n) ) = size {
678
- let warps_per_cube = 8 ;
668
+ let ( cmma_m, cmma_n, cmma_k) =
669
+ find_cmma_size :: < R , E > ( client, gemm_m as u32 , gemm_k as u32 , gemm_n as u32 ) . ok_or_else (
670
+ || {
671
+ ConvLaunchError :: Matmul ( MatmulLaunchError :: Unavailable (
672
+ MatmulAvailabilityError :: CmmaInstructionUnavailable {
673
+ input : E :: as_elem_native_unchecked ( ) ,
674
+ output : E :: as_elem_native_unchecked ( ) ,
675
+ m : 16 ,
676
+ n : 16 ,
677
+ k : cmma_k as u32 ,
678
+ } ,
679
+ ) )
680
+ } ,
681
+ ) ?;
682
+
683
+ let warps_per_cube = 8 ;
684
+
685
+ let smem_size = ( ( cmma_m + cmma_n) * cmma_k * warps_per_cube) as usize * size_of :: < f16 > ( ) ;
686
+ if <R :: Compiler as Compiler >:: max_shared_memory_size ( ) < smem_size {
687
+ return Err ( ConvLaunchError :: Matmul ( MatmulLaunchError :: InvalidConfig (
688
+ Box :: new ( "Not enough shared memory" ) ,
689
+ ) ) ) ;
690
+ }
679
691
680
- let smem_size = ( ( cmma_m + cmma_n) * cmma_k * warps_per_cube) as usize * size_of :: < f16 > ( ) ;
681
- let topology = client. properties ( ) . hardware_properties ( ) ;
682
- let not_intel = topology. plane_size_min >= 32 ;
692
+ let topology = client. properties ( ) . hardware_properties ( ) ;
693
+ if topology. plane_size_min < 32 {
694
+ return Err ( ConvLaunchError :: Matmul ( MatmulLaunchError :: Unavailable (
695
+ MatmulAvailabilityError :: PlaneDimUnsupported {
696
+ plane_dim : topology. plane_size_min ,
697
+ } ,
698
+ ) ) ) ;
699
+ }
683
700
684
- <R :: Compiler as Compiler >:: max_shared_memory_size ( ) >= smem_size && groups == 1 && not_intel
685
- } else {
686
- false
701
+ if groups != 1 {
702
+ return Err ( ConvLaunchError :: Groups ( groups) ) ;
687
703
}
704
+ Ok ( ( ) )
688
705
}
689
706
690
707
fn padded_k (
0 commit comments