
Commit 3a1d670

Allow compiling ggml-cuda without mmq or flash attention

1 parent: badbdfc

6 files changed: +45, -13 lines

ggml/CMakeLists.txt (+1)

@@ -149,6 +149,7 @@ set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
                                "ggml: max. batch size for using peer access")
 option(GGML_CUDA_NO_PEER_COPY  "ggml: do not use peer to peer copies" OFF)
 option(GGML_CUDA_NO_VMM        "ggml: do not try to use CUDA VMM" OFF)
+option(GGML_CUDA_FA            "ggml: compile with FlashAttention" ON)
 option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
 option(GGML_CUDA_GRAPHS        "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})

ggml/src/ggml-cuda/CMakeLists.txt (+24, -13)

@@ -28,24 +28,35 @@ if (CUDAToolkit_FOUND)
     list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")

     file(GLOB GGML_SOURCES_CUDA "*.cu")
-    file(GLOB SRCS "template-instances/fattn-wmma*.cu")
-    list(APPEND GGML_SOURCES_CUDA ${SRCS})
-    file(GLOB SRCS "template-instances/mmq*.cu")
-    list(APPEND GGML_SOURCES_CUDA ${SRCS})
-
-    if (GGML_CUDA_FA_ALL_QUANTS)
-        file(GLOB SRCS "template-instances/fattn-vec*.cu")
+    if (GGML_CUDA_FA)
+        file(GLOB SRCS "template-instances/fattn-wmma*.cu")
         list(APPEND GGML_SOURCES_CUDA ${SRCS})
-        add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
     else()
-        file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
-        list(APPEND GGML_SOURCES_CUDA ${SRCS})
-        file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
-        list(APPEND GGML_SOURCES_CUDA ${SRCS})
-        file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu")
+        list(FILTER GGML_SOURCES_CUDA EXCLUDE REGEX ".*fattn.*")
+        list(FILTER GGML_HEADERS_CUDA EXCLUDE REGEX ".*fattn.*")
+        # message(FATAL_ERROR ${GGML_SOURCES_CUDA})
+    endif()
+    if (NOT GGML_CUDA_FORCE_CUBLAS)
+        file(GLOB SRCS "template-instances/mmq*.cu")
         list(APPEND GGML_SOURCES_CUDA ${SRCS})
     endif()

+    if (GGML_CUDA_FA)
+        add_compile_definitions(GGML_CUDA_FA)
+        if (GGML_CUDA_FA_ALL_QUANTS)
+            file(GLOB SRCS "template-instances/fattn-vec*.cu")
+            list(APPEND GGML_SOURCES_CUDA ${SRCS})
+            add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
+        else()
+            file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
+            list(APPEND GGML_SOURCES_CUDA ${SRCS})
+            file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
+            list(APPEND GGML_SOURCES_CUDA ${SRCS})
+            file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu")
+            list(APPEND GGML_SOURCES_CUDA ${SRCS})
+        endif()
+    endif()
+
     ggml_add_backend_library(ggml-cuda
         ${GGML_HEADERS_CUDA}
         ${GGML_SOURCES_CUDA}

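Taken together, the two CMake changes make both kernel families optional: GGML_CUDA_FA gates the FlashAttention sources and GGML_CUDA_FORCE_CUBLAS now also prunes the template-instances/mmq*.cu sources. As a usage sketch, a configure line for a CUDA build without either family might look like the following (the two GGML_CUDA_FA / GGML_CUDA_FORCE_CUBLAS option names come from the diffs above; treating -DGGML_CUDA=ON and the rest of the command line as the project's usual configure flow is an assumption):

    # hypothetical configure: CUDA backend without FlashAttention or MMQ kernels
    cmake -B build -DGGML_CUDA=ON -DGGML_CUDA_FA=OFF -DGGML_CUDA_FORCE_CUBLAS=ON
    cmake --build build --config Release
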
ggml/src/ggml-cuda/common.cuh (+4)

@@ -151,6 +151,10 @@ typedef float2 dfloat2;
 #define FLASH_ATTN_AVAILABLE
 #endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)

+#if !defined(GGML_CUDA_FA)
+#undef FLASH_ATTN_AVAILABLE
+#endif
+
 static constexpr bool fast_fp16_available(const int cc) {
     return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
 }

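The #undef is layered after the existing MUSA guard, so FLASH_ATTN_AVAILABLE ends up defined only when the target architecture supports it and the build opted in via GGML_CUDA_FA. A minimal, self-contained sketch of the same layering (only FLASH_ATTN_AVAILABLE and GGML_CUDA_FA come from the diff; the file name and output are illustrative):

    // gate_demo.cpp (hypothetical): compile with or without -DGGML_CUDA_FA to
    // see both outcomes. The real header also un-defines the macro for old
    // MUSA architectures; that branch is elided here.
    #include <cstdio>

    #define FLASH_ATTN_AVAILABLE   // optimistic default, as in common.cuh

    #if !defined(GGML_CUDA_FA)     // the build did not opt in: take it back
    #undef FLASH_ATTN_AVAILABLE
    #endif

    int main() {
    #ifdef FLASH_ATTN_AVAILABLE
        std::puts("FLASH_ATTN_AVAILABLE is defined");
    #else
        std::puts("FLASH_ATTN_AVAILABLE is not defined");
    #endif
    }
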
ggml/src/ggml-cuda/ggml-cuda.cu (+6)

@@ -16,7 +16,9 @@
 #include "ggml-cuda/cpy.cuh"
 #include "ggml-cuda/cross-entropy-loss.cuh"
 #include "ggml-cuda/diagmask.cuh"
+#ifdef FLASH_ATTN_AVAILABLE
 #include "ggml-cuda/fattn.cuh"
+#endif
 #include "ggml-cuda/getrows.cuh"
 #include "ggml-cuda/im2col.cuh"
 #include "ggml-cuda/mmq.cuh"

@@ -2160,8 +2162,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_argsort(ctx, dst);
             break;
         case GGML_OP_FLASH_ATTN_EXT:
+#ifdef FLASH_ATTN_AVAILABLE
             ggml_cuda_flash_attn_ext(ctx, dst);
             break;
+#else
+            return false;
+#endif
         case GGML_OP_CROSS_ENTROPY_LOSS:
             ggml_cuda_cross_entropy_loss(ctx, dst);
             break;

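When the kernels are compiled out, ggml_cuda_compute_forward now answers false for GGML_OP_FLASH_ATTN_EXT instead of calling into a function that no longer exists; note the #else branch exits via return rather than break, so there is no fallthrough into the next case. The shape of the pattern, reduced to a free-standing sketch (all names here are illustrative stand-ins, not ggml API):

    // Feature-gated dispatch: a compiled-out case reports "unsupported" so the
    // caller can route the op elsewhere. Build with -DFEATURE_FA to flip it.
    #include <cstdio>

    enum class op_kind { argsort, flash_attn, cross_entropy };

    static bool compute(op_kind op) {
        switch (op) {
            case op_kind::argsort:
                return true;     // always compiled in
            case op_kind::flash_attn:
    #ifdef FEATURE_FA            // stand-in for FLASH_ATTN_AVAILABLE
                return true;     // real code runs the kernel, then breaks
    #else
                return false;    // compiled out: return, so no fallthrough
    #endif
            case op_kind::cross_entropy:
                return true;
        }
        return false;
    }

    int main() {
        std::printf("flash_attn handled: %s\n",
                    compute(op_kind::flash_attn) ? "yes" : "no");
    }
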
ggml/src/ggml-cuda/mmq.cu (+8)

@@ -1,5 +1,12 @@
 #include "mmq.cuh"

+#ifdef GGML_CUDA_FORCE_CUBLAS
+void ggml_cuda_op_mul_mat_q(
+    ggml_backend_cuda_context &,
+    const ggml_tensor *, const ggml_tensor *, ggml_tensor *, const char *, const float *,
+    const char *, float *, const int64_t, const int64_t, const int64_t,
+    const int64_t, cudaStream_t) {}
+#else
 void ggml_cuda_op_mul_mat_q(
     ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,

@@ -94,6 +101,7 @@ void ggml_cuda_op_mul_mat_q(
     GGML_UNUSED(dst);
     GGML_UNUSED(src1_ddf_i);
 }
+#endif // GGML_CUDA_FORCE_CUBLAS

 bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
 #ifdef GGML_CUDA_FORCE_CUBLAS

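Under GGML_CUDA_FORCE_CUBLAS the real body of ggml_cuda_op_mul_mat_q is replaced by an empty stub with the same signature, so translation units that reference the symbol still link; ggml_cuda_should_use_mmq already checks the same macro (visible at the bottom of the hunk), which presumably keeps the stub from ever being called. The pattern in isolation (hypothetical names throughout; the float "kernel" is a placeholder, not the real quantized mat-mul):

    // "Keep the symbol, drop the body": a no-op definition behind a feature
    // macro satisfies the linker while a runtime predicate steers callers away.
    #include <cstdio>

    #ifdef FORCE_FALLBACK   // stand-in for GGML_CUDA_FORCE_CUBLAS
    void op_mul_mat_q(const float *, float *, int) {}      // stub, never reached
    static bool should_use_mmq() { return false; }
    #else
    void op_mul_mat_q(const float * src, float * dst, int n) {
        for (int i = 0; i < n; ++i) dst[i] = 2.0f * src[i]; // placeholder "kernel"
    }
    static bool should_use_mmq() { return true; }
    #endif

    int main() {
        float src[4] = {1, 2, 3, 4}, dst[4] = {};
        if (should_use_mmq()) op_mul_mat_q(src, dst, 4);
        std::printf("dst[0] = %g\n", dst[0]);
    }
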
ggml/src/ggml-cuda/mmq.cuh (+2)

@@ -2906,6 +2906,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
 #define DECL_MMQ_CASE(type) \
     template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \

+#if !defined(GGML_CUDA_FORCE_CUBLAS)
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);

@@ -2924,6 +2925,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
+#endif

 // -------------------------------------------------------------------------------------------------------------------------

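These extern template declarations refer to explicit instantiations that live in the template-instances/mmq*.cu files, which the CMake change above stops compiling when GGML_CUDA_FORCE_CUBLAS is set, so the declarations are guarded out together with the definitions they depend on. A reduced sketch of that extern-template pairing (all names illustrative; in the real code the two halves sit in different files):

    // An extern template declaration suppresses per-TU instantiation and
    // promises a matching explicit instantiation elsewhere, so both halves
    // must be compiled out under the same condition.
    template <typename T> T twice(T x) { return x + x; }

    #if !defined(FORCE_FALLBACK)          // stand-in for GGML_CUDA_FORCE_CUBLAS
    extern template int twice<int>(int);  // header side: "defined elsewhere"
    template int twice<int>(int);         // normally in template-instances/*.cu
    #endif

    int main() { return twice(21) == 42 ? 0 : 1; }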