Allow compiling CUDA without MMQ and FlashAttention #11190

Status: Open. Wants to merge 1 commit into base branch master.
1 change: 1 addition & 0 deletions ggml/CMakeLists.txt
@@ -149,6 +149,7 @@ set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
"ggml: max. batch size for using peer access")
option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)
option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF)
+option(GGML_CUDA_FA "ggml: compile with FlashAttention" ON)
Collaborator review comment (suggested change):
-option(GGML_CUDA_FA "ggml: compile with FlashAttention" ON)
+option(GGML_CUDA_FA "ggml: compile ggml FlashAttention kernels" ON)

option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
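Usage note (inferred from the options in this diff, not stated explicitly in the PR): a CUDA build without the FlashAttention kernels would be configured with -DGGML_CUDA=ON -DGGML_CUDA_FA=OFF, and the existing -DGGML_CUDA_FORCE_CUBLAS=ON additionally skips the MMQ kernels, as the ggml-cuda CMakeLists change below shows.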

37 changes: 24 additions & 13 deletions ggml/src/ggml-cuda/CMakeLists.txt
@@ -28,24 +28,35 @@ if (CUDAToolkit_FOUND)
    list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")

    file(GLOB GGML_SOURCES_CUDA "*.cu")
-    file(GLOB SRCS "template-instances/fattn-wmma*.cu")
-    list(APPEND GGML_SOURCES_CUDA ${SRCS})
-    file(GLOB SRCS "template-instances/mmq*.cu")
-    list(APPEND GGML_SOURCES_CUDA ${SRCS})
-
-    if (GGML_CUDA_FA_ALL_QUANTS)
-        file(GLOB SRCS "template-instances/fattn-vec*.cu")
+    if (GGML_CUDA_FA)
+        file(GLOB SRCS "template-instances/fattn-wmma*.cu")
        list(APPEND GGML_SOURCES_CUDA ${SRCS})
-        add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
    else()
-        file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
-        list(APPEND GGML_SOURCES_CUDA ${SRCS})
-        file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
-        list(APPEND GGML_SOURCES_CUDA ${SRCS})
-        file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu")
+        list(FILTER GGML_SOURCES_CUDA EXCLUDE REGEX ".*fattn.*")
+        list(FILTER GGML_HEADERS_CUDA EXCLUDE REGEX ".*fattn.*")
+        # message(FATAL_ERROR ${GGML_SOURCES_CUDA})
Collaborator review comment (on the commented-out message(FATAL_ERROR) line above): Forgot to remove?
+    endif()
+    if (NOT GGML_CUDA_FORCE_CUBLAS)
+        file(GLOB SRCS "template-instances/mmq*.cu")
        list(APPEND GGML_SOURCES_CUDA ${SRCS})
    endif()

+    if (GGML_CUDA_FA)
+        add_compile_definitions(GGML_CUDA_FA)
+        if (GGML_CUDA_FA_ALL_QUANTS)
+            file(GLOB SRCS "template-instances/fattn-vec*.cu")
+            list(APPEND GGML_SOURCES_CUDA ${SRCS})
+            add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
+        else()
+            file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
+            list(APPEND GGML_SOURCES_CUDA ${SRCS})
+            file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
+            list(APPEND GGML_SOURCES_CUDA ${SRCS})
+            file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu")
+            list(APPEND GGML_SOURCES_CUDA ${SRCS})
+        endif()
+    endif()
+
    ggml_add_backend_library(ggml-cuda
                             ${GGML_HEADERS_CUDA}
                             ${GGML_SOURCES_CUDA}
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -151,6 +151,10 @@ typedef float2 dfloat2;
#define FLASH_ATTN_AVAILABLE
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)

+#if !defined(GGML_CUDA_FA)
+#undef FLASH_ATTN_AVAILABLE
+#endif
Collaborator review comment (suggested change):
-#endif
+#endif // !defined(GGML_CUDA_FA)


static constexpr bool fast_fp16_available(const int cc) {
return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
}
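As a reading aid, here is a minimal sketch (not part of this PR; the helper name is hypothetical) of how a CUDA source file is expected to consume the macro pair above: when the build turns GGML_CUDA_FA off, the compile definition is never added, FLASH_ATTN_AVAILABLE is #undef'd, and the guarded path compiles out.

```cpp
// Hypothetical consumer of FLASH_ATTN_AVAILABLE, illustrating the guard above.
// common.cuh defines the macro, then removes it again when GGML_CUDA_FA is not set.
#include "common.cuh"

static bool flash_attn_compiled_in() { // hypothetical helper, not in ggml
#ifdef FLASH_ATTN_AVAILABLE
    return true;   // FlashAttention kernels were built into this backend
#else
    return false;  // kernels compiled out; callers must take the non-FA path
#endif
}
```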
6 changes: 6 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -16,7 +16,9 @@
#include "ggml-cuda/cpy.cuh"
#include "ggml-cuda/cross-entropy-loss.cuh"
#include "ggml-cuda/diagmask.cuh"
+#ifdef FLASH_ATTN_AVAILABLE
#include "ggml-cuda/fattn.cuh"
+#endif
Collaborator review comment (suggested change):
-#endif
+#endif // FLASH_ATTN_AVAILABLE

#include "ggml-cuda/getrows.cuh"
#include "ggml-cuda/im2col.cuh"
#include "ggml-cuda/mmq.cuh"
@@ -2160,8 +2162,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_argsort(ctx, dst);
break;
case GGML_OP_FLASH_ATTN_EXT:
+#ifdef FLASH_ATTN_AVAILABLE
ggml_cuda_flash_attn_ext(ctx, dst);
break;
+#else
+return false;
+#endif
Collaborator review comment (suggested change):
-#endif
+#endif // FLASH_ATTN_AVAILABLE

case GGML_OP_CROSS_ENTROPY_LOSS:
ggml_cuda_cross_entropy_loss(ctx, dst);
break;
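Returning false from the FLASH_ATTN_EXT case tells the caller that the CUDA backend did not handle the op. A condensed sketch of that dispatch pattern (paraphrased, not the real ggml_cuda_compute_forward; the function name is hypothetical):

```cpp
// Paraphrased dispatch sketch: with GGML_CUDA_FA off, FLASH_ATTN_AVAILABLE is
// undefined, so the flash-attention case reports "not handled" instead of
// calling into kernels that were never compiled.
#include "ggml.h"

static bool example_compute_forward(enum ggml_op op) { // hypothetical helper
    switch (op) {
        case GGML_OP_FLASH_ATTN_EXT:
#ifdef FLASH_ATTN_AVAILABLE
            // ggml_cuda_flash_attn_ext(ctx, dst) would run here
            return true;
#else
            return false;  // compiled without FlashAttention support
#endif
        default:
            return true;   // other ops are handled as before
    }
}
```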
8 changes: 8 additions & 0 deletions ggml/src/ggml-cuda/mmq.cu
@@ -1,5 +1,12 @@
#include "mmq.cuh"

+#ifdef GGML_CUDA_FORCE_CUBLAS
+void ggml_cuda_op_mul_mat_q(
+    ggml_backend_cuda_context &,
+    const ggml_tensor *, const ggml_tensor *, ggml_tensor *, const char *, const float *,
+    const char *, float *, const int64_t, const int64_t, const int64_t,
+    const int64_t, cudaStream_t) {}
+#else
Collaborator review comment on lines +3 to +9: Add GGML_ABORT("CUDA was compiled without MMQ support") to the function instead.
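For reference, a sketch of the reviewer's suggestion applied to the stub above (same signature as in the diff; the only change is the abort), so a GGML_CUDA_FORCE_CUBLAS build fails loudly if the MMQ path is ever reached:

```cpp
// Sketch of the suggested stub for mmq.cu: abort instead of silently doing nothing.
// Drop-in for the body shown in the diff above; mmq.cuh is already included there.
#ifdef GGML_CUDA_FORCE_CUBLAS
void ggml_cuda_op_mul_mat_q(
    ggml_backend_cuda_context &,
    const ggml_tensor *, const ggml_tensor *, ggml_tensor *, const char *, const float *,
    const char *, float *, const int64_t, const int64_t, const int64_t,
    const int64_t, cudaStream_t) {
    GGML_ABORT("CUDA was compiled without MMQ support");
}
#endif // GGML_CUDA_FORCE_CUBLAS
```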

void ggml_cuda_op_mul_mat_q(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
@@ -94,6 +101,7 @@ void ggml_cuda_op_mul_mat_q(
GGML_UNUSED(dst);
GGML_UNUSED(src1_ddf_i);
}
+#endif // GGML_CUDA_FORCE_CUBLAS

bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
#ifdef GGML_CUDA_FORCE_CUBLAS
2 changes: 2 additions & 0 deletions ggml/src/ggml-cuda/mmq.cuh
@@ -2906,6 +2906,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
#define DECL_MMQ_CASE(type) \
template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \

+#if !defined(GGML_CUDA_FORCE_CUBLAS)
extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
@@ -2924,6 +2925,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
+#endif
Collaborator review comment (suggested change):
-#endif
+#endif // !defined(GGML_CUDA_FORCE_CUBLAS)


// -------------------------------------------------------------------------------------------------------------------------
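For context on why the extern declarations need this guard: each file under template-instances/ provides one explicit MMQ instantiation, roughly as sketched below (paraphrased from ggml's layout, not copied from this PR). The CMakeLists change above skips those files entirely when GGML_CUDA_FORCE_CUBLAS is ON, so guarding the matching extern DECL_MMQ_CASE declarations keeps mmq.cuh consistent with what actually gets compiled.

```cpp
// Rough shape of a template-instances/mmq-instance-*.cu file (one per quant type).
// These files are skipped by CMake when GGML_CUDA_FORCE_CUBLAS is ON, which is why
// the extern DECL_MMQ_CASE(...) declarations in mmq.cuh are guarded the same way.
#include "../mmq.cuh"

DECL_MMQ_CASE(GGML_TYPE_Q4_0);
```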
