Skip to content

Commit 3ad5451

Browse files
authored
Add some minimal optimizations for CDNA (#10498)
* Add some minimal optimizations for CDNA * ggml_cuda: set launch bounds also for GCN as it helps there too
1 parent 46c69e0 commit 3ad5451

File tree

6 files changed

+36
-8
lines changed

6 files changed

+36
-8
lines changed

ggml/src/ggml-cuda/common.cuh

+14-3
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,20 @@
4747
#define CC_TURING 750
4848
#define CC_AMPERE 800
4949
#define CC_OFFSET_AMD 1000000
50-
#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
51-
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
52-
#define CC_RDNA3 (CC_OFFSET_AMD + 1100)
50+
51+
// GCN/CDNA, wave size is 64
52+
#define CC_GCN4 (CC_OFFSET_AMD + 803) // Tonga, Fiji, Polaris, minimum for fast fp16
53+
#define CC_VEGA (CC_OFFSET_AMD + 900) // Vega56/64, minimum for fp16 dual issue
54+
#define CC_VEGA20 (CC_OFFSET_AMD + 906) // MI50/Radeon VII, minimum for dp4a
55+
#define CC_CDNA (CC_OFFSET_AMD + 908) // MI100, minimum for MFMA, acc registers
56+
#define CC_CDNA2 (CC_OFFSET_AMD + 910) // MI210, minimum for acc register renaming
57+
#define CC_CDNA3 (CC_OFFSET_AMD + 942) // MI300
58+
59+
// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32
60+
#define CC_RDNA1 (CC_OFFSET_AMD + 1010) // RX 5000
61+
#define CC_RDNA2 (CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a
62+
#define CC_RDNA3 (CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA
63+
5364
#define CC_QY1 210
5465
#define CC_QY2 220
5566

ggml/src/ggml-cuda/ggml-cuda.cu

+10-1
Original file line numberDiff line numberDiff line change
@@ -1107,14 +1107,19 @@ static void ggml_cuda_op_mul_mat_cublas(
11071107
const half alpha_f16 = 1.0f;
11081108
const half beta_f16 = 0.0f;
11091109

1110+
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
1111+
if (ggml_cuda_info().devices[ctx.device].cc == CC_CDNA) {
1112+
cu_compute_type = CUBLAS_COMPUTE_32F;
1113+
}
1114+
11101115
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
11111116
CUBLAS_CHECK(
11121117
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
11131118
row_diff, src1_ncols, ne10,
11141119
&alpha_f16, src0_ptr, CUDA_R_16F, ne00,
11151120
src1_ptr, CUDA_R_16F, ne10,
11161121
&beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
1117-
CUBLAS_COMPUTE_16F,
1122+
cu_compute_type,
11181123
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
11191124

11201125
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
@@ -1607,6 +1612,10 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
16071612
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
16081613
cudaDataType_t cu_data_type = CUDA_R_16F;
16091614

1615+
if (ggml_cuda_info().devices[ctx.device].cc == CC_CDNA) {
1616+
cu_compute_type = CUBLAS_COMPUTE_32F;
1617+
}
1618+
16101619
// dst strides
16111620
size_t nbd2 = dst->nb[2];
16121621
size_t nbd3 = dst->nb[3];

ggml/src/ggml-cuda/mmq.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -148,5 +148,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
148148
return cc < CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
149149
}
150150

151-
return cc < CC_RDNA3 || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
151+
return (cc < CC_RDNA3 && cc != CC_CDNA && cc != CC_VEGA20) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
152152
}

ggml/src/ggml-cuda/mmq.cuh

+2-2
Original file line numberDiff line numberDiff line change
@@ -2570,9 +2570,9 @@ static __device__ void mul_mat_q_process_tile(
25702570

25712571
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
25722572
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
2573-
#if defined(RDNA3) || defined(RDNA2)
2573+
#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
25742574
__launch_bounds__(WARP_SIZE*nwarps, 2)
2575-
#endif // defined(RDNA3) || defined(RDNA2)
2575+
#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
25762576
#else
25772577
#if __CUDA_ARCH__ >= CC_VOLTA
25782578
__launch_bounds__(WARP_SIZE*nwarps, 1)

ggml/src/ggml-cuda/mmvq.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ static void mul_mat_vec_q_cuda(
142142
int64_t nwarps = 1;
143143
int64_t rows_per_cuda_block = 1;
144144

145-
if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
145+
if (ggml_cuda_info().devices[id].cc < CC_CDNA || ggml_cuda_info().devices[id].cc == CC_RDNA1) { // NVIDIA and AMD older than RDNA2 but not CDNA
146146
switch(ncols_y) {
147147
case 1:
148148
nwarps = 4;

ggml/src/ggml-cuda/vendors/hip.h

+8
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,14 @@
9595

9696
#define __CUDA_ARCH__ 1300
9797

98+
#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
99+
#define GCN
100+
#endif
101+
102+
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
103+
#define CDNA
104+
#endif
105+
98106
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
99107
defined(__gfx1150__) || defined(__gfx1151__)
100108
#define RDNA3

0 commit comments

Comments
 (0)