@@ -1107,14 +1107,19 @@ static void ggml_cuda_op_mul_mat_cublas(
         const half alpha_f16 = 1.0f;
         const half beta_f16 = 0.0f;
 
+        cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
+        if (ggml_cuda_info().devices[ctx.device].cc == CC_CDNA) {
+            cu_compute_type = CUBLAS_COMPUTE_32F;
+        }
+
         CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
         CUBLAS_CHECK(
             cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
                     row_diff, src1_ncols, ne10,
                     &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
                                 src1_ptr, CUDA_R_16F, ne10,
                     &beta_f16,  dst_f16.get(), CUDA_R_16F, ldc,
-                    CUBLAS_COMPUTE_16F,
+                    cu_compute_type,
                     CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
@@ -1607,6 +1612,10 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
     cudaDataType_t      cu_data_type    = CUDA_R_16F;
 
+    if (ggml_cuda_info().devices[ctx.device].cc == CC_CDNA) {
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+    }
+
     // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
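Both hunks apply the same pattern: the GEMM operands and result stay in FP16 (`CUDA_R_16F`), but on CDNA devices the accumulation switches from `CUBLAS_COMPUTE_16F` to `CUBLAS_COMPUTE_32F`. For illustration only, here is a minimal standalone sketch of such an FP16-in/FP32-accumulate `cublasGemmEx` call, not the PR's code: the `is_cdna` flag is a hypothetical stand-in for `ggml_cuda_info().devices[ctx.device].cc == CC_CDNA`, and error checking is omitted. One detail the sketch handles explicitly: cuBLAS reads `alpha`/`beta` in the compute type, so the FP32 path passes `float` scalars rather than `half`.

```cpp
// Sketch: 2x2 FP16 GEMM via cublasGemmEx, accumulating in FP32 as the
// CDNA branch above does. `is_cdna` is a hypothetical stand-in for the
// real device check; error checking is omitted for brevity.
#include <cstdio>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cublas_v2.h>

int main() {
    const bool is_cdna = true; // assumption: behave as if on a CDNA device
    const cublasComputeType_t cu_compute_type =
        is_cdna ? CUBLAS_COMPUTE_32F : CUBLAS_COMPUTE_16F;

    // A is all ones, B is all twos, so every element of A*B is 4.
    const int n = 2;
    half ha[4], hb[4], hc[4];
    for (int i = 0; i < 4; ++i) {
        ha[i] = __float2half(1.0f);
        hb[i] = __float2half(2.0f);
    }

    half *da, *db, *dc;
    cudaMalloc((void **) &da, sizeof(ha));
    cudaMalloc((void **) &db, sizeof(hb));
    cudaMalloc((void **) &dc, sizeof(hc));
    cudaMemcpy(da, ha, sizeof(ha), cudaMemcpyHostToDevice);
    cudaMemcpy(db, hb, sizeof(hb), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);

    // cuBLAS interprets alpha/beta according to the compute type,
    // so the FP32 path needs float scalars.
    const half  alpha_f16 = __float2half(1.0f), beta_f16 = __float2half(0.0f);
    const float alpha_f32 = 1.0f,               beta_f32 = 0.0f;
    const void * alpha = is_cdna ? (const void *) &alpha_f32 : (const void *) &alpha_f16;
    const void * beta  = is_cdna ? (const void *) &beta_f32  : (const void *) &beta_f16;

    cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
            n, n, n,
            alpha, da, CUDA_R_16F, n,
                   db, CUDA_R_16F, n,
            beta,  dc, CUDA_R_16F, n,
            cu_compute_type,
            CUBLAS_GEMM_DEFAULT_TENSOR_OP);

    cudaMemcpy(hc, dc, sizeof(hc), cudaMemcpyDeviceToHost);
    printf("c[0][0] = %.1f (expected 4.0)\n", __half2float(hc[0]));

    cublasDestroy(handle);
    cudaFree(da);
    cudaFree(db);
    cudaFree(dc);
    return 0;
}
```

Build with `nvcc -o gemm_sketch gemm_sketch.cu -lcublas`. On the ROCm build the equivalent call in ggml-cuda is routed to hipBLAS through the project's cuBLAS-to-hipBLAS compatibility defines.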