Skip to content

Commit

Permalink
Update the flash attn kernels.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Jul 15, 2024
1 parent d74fbed commit 047be5b
Show file tree
Hide file tree
Showing 51 changed files with 2,275 additions and 900 deletions.
18 changes: 17 additions & 1 deletion candle-flash-attn/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
use anyhow::{Context, Result};
use std::path::PathBuf;

const KERNEL_FILES: [&str; 17] = [
const KERNEL_FILES: [&str; 33] = [
"kernels/flash_api.cu",
"kernels/flash_fwd_hdim128_fp16_sm80.cu",
"kernels/flash_fwd_hdim160_fp16_sm80.cu",
Expand All @@ -22,6 +22,22 @@ const KERNEL_FILES: [&str; 17] = [
"kernels/flash_fwd_hdim32_bf16_sm80.cu",
"kernels/flash_fwd_hdim64_bf16_sm80.cu",
"kernels/flash_fwd_hdim96_bf16_sm80.cu",
"kernels/flash_fwd_hdim128_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim160_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim192_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim224_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim256_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim32_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim64_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim96_fp16_causal_sm80.cu",
"kernels/flash_fwd_hdim128_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim160_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim192_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim224_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim256_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim32_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim64_bf16_causal_sm80.cu",
"kernels/flash_fwd_hdim96_bf16_causal_sm80.cu",
];

fn main() -> Result<()> {
Expand Down
2 changes: 1 addition & 1 deletion candle-flash-attn/cutlass
Submodule cutlass updated 2187 files
78 changes: 45 additions & 33 deletions candle-flash-attn/kernels/alibi.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,50 +13,62 @@ using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_causal, typename Engine, typename Layout>
inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
const int col_idx_offset_,
const int max_seqlen_k,
const int row_idx_offset,
const int max_seqlen_q,
const int warp_row_stride,
const float alibi_slope) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
template <bool Is_causal>
struct Alibi {

const float alibi_slope;
const int max_seqlen_k, max_seqlen_q;

__forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
: alibi_slope(alibi_slope)
, max_seqlen_k(max_seqlen_k)
, max_seqlen_q(max_seqlen_q) {
};


template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
const int col_idx_offset_,
const int row_idx_offset,
const int warp_row_stride) {
// tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
}
}
}
}
} else { // Bias depends on both row_idx and col_idx
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
} else { // Bias depends on both row_idx and col_idx
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8;
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
}
}
}
}
}
}
}

};

} // namespace flash
4 changes: 2 additions & 2 deletions candle-flash-attn/kernels/block_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ struct BlockInfo {
}

template <typename index_t>
inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
__forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
}

template <typename index_t>
inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
__forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
}

Expand Down
94 changes: 94 additions & 0 deletions candle-flash-attn/kernels/dropout.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/

#pragma once

#include "philox.cuh"
#include "utils.h"

namespace flash {

struct Dropout {

const unsigned long long seed, offset;
const uint8_t p_dropout_in_uint8_t;

__forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
const uint8_t p_dropout_in_uint8_t,
const int bid, const int hid, const int tid, const int nheads)
: seed(seed)
, offset(offset + (bid * nheads + hid) * 32 + tid % 32)
, p_dropout_in_uint8_t(p_dropout_in_uint8_t) {
}

template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
__forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
int block_row_start, int block_col_start, int block_row_stride) {
// convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2)
Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_dropout(tensor_.layout()));
using T = typename Engine::value_type;
auto encode_dropout = [](bool keep, T val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
};
static_assert(decltype(size<2>(tensor))::value % 2 == 0);
const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
// if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
#pragma unroll
for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
uint2 rowcol = make_uint2(block_row_start, block_col_start);
#pragma unroll
for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
// if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
// if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
// Special implementation for 16-bit types: we duplicate the threshold to the
// low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
// to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
// and the high 16 bits will be either 0xffff or 0x0000, depending on whether
// the random value is less than the threshold.
// We then do a bit-wise AND between the mask and the original value (in 32-bit).
// We're exploiting the fact that floating point comparison is equivalent to integer
// comparison, since we're comparing unsigned integers whose top 8-bits are zero.
if (!encode_dropout_in_sign_bit
&& (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
uint16_t rnd_16[16];
#pragma unroll
for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
#pragma unroll
for (int j = 0; j < 2; j++) {
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
// if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
#pragma unroll
for (int i = 0; i < 4; i++) {
uint32_t mask;
asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
tensor_uint32(i) &= mask;
}
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
}
} else {
#pragma unroll
for (int j = 0; j < 2; j++) {
#pragma unroll
for (int i = 0; i < 8; i++) {
tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
}
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
}
}
// // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
// // }
}
}
}

};

} // namespace flash
8 changes: 8 additions & 0 deletions candle-flash-attn/kernels/error.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#pragma once

#define C10_CUDA_CHECK(EXPR) \
do { \
const cudaError_t __err = EXPR; \
} while (0)

#define C10_CUDA_KERNEL_LAUNCH_CHECK() C10_CUDA_CHECK(cudaGetLastError())
35 changes: 29 additions & 6 deletions candle-flash-attn/kernels/flash.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,22 @@
#include <cuda.h>
#include <vector>

// #ifdef OLD_GENERATOR_PATH
// #include <ATen/CUDAGeneratorImpl.h>
// #else
// #include <ATen/cuda/CUDAGeneratorImpl.h>
// #endif
//
// #include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack

constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Qkv_params {
using index_t = uint32_t;
using index_t = int64_t;
// The QKV matrices.
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
Expand Down Expand Up @@ -59,7 +67,7 @@ struct Flash_fwd_params : public Qkv_params {
void * __restrict__ softmax_lseaccum_ptr;

// The dimensions.
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q;

// The scaling factors for the kernel.
float scale_softmax;
Expand Down Expand Up @@ -91,7 +99,12 @@ struct Flash_fwd_params : public Qkv_params {
void * __restrict__ rotary_sin_ptr;

// The indices to index into the KV cache.
int *__restrict__ cache_batch_idx;
int * __restrict__ cache_batch_idx;

// Paged KV cache
int * __restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;

// The dropout probability (probability of keeping an activation).
float p_dropout;
Expand All @@ -105,6 +118,13 @@ struct Flash_fwd_params : public Qkv_params {

// Local window size
int window_size_left, window_size_right;
float softcap;

// Random state.
// at::PhiloxCudaState philox_args;

// Pointer to the RNG seed (idx 0) and offset (idx 1).
uint64_t * rng_state;

bool is_bf16;
bool is_causal;
Expand All @@ -119,6 +139,9 @@ struct Flash_fwd_params : public Qkv_params {

void * __restrict__ alibi_slopes_ptr;
index_t alibi_slopes_batch_stride;

bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).
};

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -165,7 +188,7 @@ struct Flash_bwd_params : public Flash_fwd_params {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
20 changes: 10 additions & 10 deletions candle-flash-attn/kernels/flash_api.cu
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
#include "kernels.h"
#include "kernel_helpers.h"
#include "flash_fwd_launch_template.h"

void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
FP16_SWITCH(!params.is_bf16, [&] {
FWD_HEADDIM_SWITCH(params.d, [&] {
// if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
run_mha_fwd_<elem_type, kHeadDim>(params, stream);
// } else {
// run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
// }
});
});
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
FP16_SWITCH(!params.is_bf16, [&] {
HEADDIM_SWITCH(params.d, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
});
});
});
}

extern "C" void run_mha(
Expand Down
10 changes: 10 additions & 0 deletions candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::bfloat16_t, true>(params, stream);
}
4 changes: 2 additions & 2 deletions candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
void run_mha_fwd_<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::bfloat16_t, false>(params, stream);
}
10 changes: 10 additions & 0 deletions candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::half_t, true>(params, stream);
}
4 changes: 2 additions & 2 deletions candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
void run_mha_fwd_<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::half_t, false>(params, stream);
}
Loading

0 comments on commit 047be5b

Please sign in to comment.