Skip to content

Commit

Permalink
block_attn xqa_optim support qwen2 (#67526)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanlehome authored Aug 19, 2024
1 parent 128ae87 commit 4c44259
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions paddle/phi/kernels/fusion/gpu/block_attn.h
Original file line number Diff line number Diff line change
Expand Up @@ -1446,6 +1446,36 @@ void dispatch_blha_gqa_kernel(const Block_AttN_params<T> &params,
stream,
load_func,
store_func)
} else if (params.gqa_num_per_partitions == 6) {
constexpr int THDS_PER_BLOCK = 1024;
BLHA_LAUNCH_GQA_KERNEL(T,
Dh,
Dh_MAX,
THREADS_PER_KEY,
THREADS_PER_VALUE,
THDS_PER_BLOCK,
BlockSize,
CACHE_TYPE,
6,
2,
stream,
load_func,
store_func)
} else if (params.gqa_num_per_partitions == 7) {
constexpr int THDS_PER_BLOCK = 1024;
BLHA_LAUNCH_GQA_KERNEL(T,
Dh,
Dh_MAX,
THREADS_PER_KEY,
THREADS_PER_VALUE,
THDS_PER_BLOCK,
BlockSize,
CACHE_TYPE,
7,
1,
stream,
load_func,
store_func)
} else if (params.gqa_num_per_partitions == 8) {
constexpr int THDS_PER_BLOCK = 1024;
BLHA_LAUNCH_GQA_KERNEL(T,
Expand Down Expand Up @@ -1701,6 +1731,7 @@ void blha(const phi::GPUContext &dev_ctx,
params.timestep = timestep + pre_cache_length;
params.inv_sqrt_dh = inv_sqrt_dh;
params.rotary_emb_dims = rotary_emb_dims;

VLOG(3) << "batch_size: " << batch_size << " q_num_head: " << q_num_head
<< " kv_num_head: " << kv_num_head << " block_size: " << block_size
<< " timestep: " << timestep;
Expand Down

0 comments on commit 4c44259

Please sign in to comment.