-
Notifications
You must be signed in to change notification settings - Fork 10.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow compiling cuda without mmq and flash attention #11190
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,24 +28,35 @@ if (CUDAToolkit_FOUND) | |
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h") | ||
|
||
file(GLOB GGML_SOURCES_CUDA "*.cu") | ||
file(GLOB SRCS "template-instances/fattn-wmma*.cu") | ||
list(APPEND GGML_SOURCES_CUDA ${SRCS}) | ||
file(GLOB SRCS "template-instances/mmq*.cu") | ||
list(APPEND GGML_SOURCES_CUDA ${SRCS}) | ||
|
||
if (GGML_CUDA_FA_ALL_QUANTS) | ||
file(GLOB SRCS "template-instances/fattn-vec*.cu") | ||
if (GGML_CUDA_FA) | ||
file(GLOB SRCS "template-instances/fattn-wmma*.cu") | ||
list(APPEND GGML_SOURCES_CUDA ${SRCS}) | ||
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) | ||
else() | ||
file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu") | ||
list(APPEND GGML_SOURCES_CUDA ${SRCS}) | ||
file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu") | ||
list(APPEND GGML_SOURCES_CUDA ${SRCS}) | ||
file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu") | ||
list(FILTER GGML_SOURCES_CUDA EXCLUDE REGEX ".*fattn.*") | ||
list(FILTER GGML_HEADERS_CUDA EXCLUDE REGEX ".*fattn.*") | ||
# message(FATAL_ERROR ${GGML_SOURCES_CUDA}) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Forgot to remove? |
||
endif() | ||
if (NOT GGML_CUDA_FORCE_CUBLAS) | ||
file(GLOB SRCS "template-instances/mmq*.cu") | ||
list(APPEND GGML_SOURCES_CUDA ${SRCS}) | ||
endif() | ||
|
||
if (GGML_CUDA_FA) | ||
add_compile_definitions(GGML_CUDA_FA) | ||
if (GGML_CUDA_FA_ALL_QUANTS) | ||
file(GLOB SRCS "template-instances/fattn-vec*.cu") | ||
list(APPEND GGML_SOURCES_CUDA ${SRCS}) | ||
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) | ||
else() | ||
file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu") | ||
list(APPEND GGML_SOURCES_CUDA ${SRCS}) | ||
file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu") | ||
list(APPEND GGML_SOURCES_CUDA ${SRCS}) | ||
file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu") | ||
list(APPEND GGML_SOURCES_CUDA ${SRCS}) | ||
endif() | ||
endif() | ||
|
||
ggml_add_backend_library(ggml-cuda | ||
${GGML_HEADERS_CUDA} | ||
${GGML_SOURCES_CUDA} | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -151,6 +151,10 @@ typedef float2 dfloat2; | |||||
#define FLASH_ATTN_AVAILABLE | ||||||
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1) | ||||||
|
||||||
#if !defined(GGML_CUDA_FA) | ||||||
#undef FLASH_ATTN_AVAILABLE | ||||||
#endif | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
// Predicate on a compute-capability value `cc`, usable in constant expressions.
// Returns true when `cc` is at least GGML_CUDA_CC_PASCAL and is not equal to 610.
// NOTE(review): 610 presumably encodes compute capability 6.1, excluded for its
// FP16 throughput — the constant is unnamed here, so confirm against the
// GGML_CUDA_CC_* definitions.
static constexpr bool fast_fp16_available(const int cc) { | ||||||
return cc >= GGML_CUDA_CC_PASCAL && cc != 610; | ||||||
} | ||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -16,7 +16,9 @@ | |||||
#include "ggml-cuda/cpy.cuh" | ||||||
#include "ggml-cuda/cross-entropy-loss.cuh" | ||||||
#include "ggml-cuda/diagmask.cuh" | ||||||
#ifdef FLASH_ATTN_AVAILABLE | ||||||
#include "ggml-cuda/fattn.cuh" | ||||||
#endif | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
#include "ggml-cuda/getrows.cuh" | ||||||
#include "ggml-cuda/im2col.cuh" | ||||||
#include "ggml-cuda/mmq.cuh" | ||||||
|
@@ -2160,8 +2162,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg | |||||
ggml_cuda_op_argsort(ctx, dst); | ||||||
break; | ||||||
case GGML_OP_FLASH_ATTN_EXT: | ||||||
#ifdef FLASH_ATTN_AVAILABLE | ||||||
ggml_cuda_flash_attn_ext(ctx, dst); | ||||||
break; | ||||||
#else | ||||||
return false; | ||||||
#endif | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
case GGML_OP_CROSS_ENTROPY_LOSS: | ||||||
ggml_cuda_cross_entropy_loss(ctx, dst); | ||||||
break; | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,12 @@ | ||
#include "mmq.cuh" | ||
|
||
#ifdef GGML_CUDA_FORCE_CUBLAS | ||
// Stub definition of ggml_cuda_op_mul_mat_q, compiled only when
// GGML_CUDA_FORCE_CUBLAS is defined (see the enclosing #ifdef).
// All parameters are unnamed and the body is empty, so a call does nothing:
// no argument is read and no output is written.
// NOTE(review): if this stub is ever reached at runtime it returns silently,
// leaving the destination buffer untouched. Presumably callers are gated so
// that it is never invoked under this flag — confirm, or make it fail loudly.
void ggml_cuda_op_mul_mat_q( | ||
ggml_backend_cuda_context &, | ||
const ggml_tensor *, const ggml_tensor *, ggml_tensor *, const char *, const float *, | ||
const char *, float *, const int64_t, const int64_t, const int64_t, | ||
const int64_t, cudaStream_t) {} | ||
#else | ||
Comment on lines
+3
to
+9
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Add |
||
void ggml_cuda_op_mul_mat_q( | ||
ggml_backend_cuda_context & ctx, | ||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, | ||
|
@@ -94,6 +101,7 @@ void ggml_cuda_op_mul_mat_q( | |
GGML_UNUSED(dst); | ||
GGML_UNUSED(src1_ddf_i); | ||
} | ||
#endif // GGML_CUDA_FORCE_CUBLAS | ||
|
||
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { | ||
#ifdef GGML_CUDA_FORCE_CUBLAS | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -2906,6 +2906,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda | |||||
#define DECL_MMQ_CASE(type) \ | ||||||
template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \ | ||||||
|
||||||
#if !defined(GGML_CUDA_FORCE_CUBLAS) | ||||||
extern DECL_MMQ_CASE(GGML_TYPE_Q4_0); | ||||||
extern DECL_MMQ_CASE(GGML_TYPE_Q4_1); | ||||||
extern DECL_MMQ_CASE(GGML_TYPE_Q5_0); | ||||||
|
@@ -2924,6 +2925,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S); | |||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S); | ||||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); | ||||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); | ||||||
#endif | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
// ------------------------------------------------------------------------------------------------------------------------- | ||||||
|
||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.