Custom benchmark with parameters #88

Merged · 18 commits · Aug 23, 2024
3 changes: 3 additions & 0 deletions README.md
@@ -134,6 +134,9 @@ After the `candle-vllm` service is running, run the Python script and enjoy effi

## Batched requests

``` shell
python3 examples/benchmark.py --batch 16 --max_tokens 1024
```
Refer to `examples/benchmark.py` for the full script.

``` python
50 changes: 26 additions & 24 deletions examples/benchmark.py
@@ -3,15 +3,26 @@
from openai import Stream
from openai.types.chat import ChatCompletionChunk
from typing import List
# Run: cargo run --release -- --port 2000 --model-id <MODEL_ID> <MODEL_TYPE> --repeat-last-n 64
import argparse
# Run candle-vllm service: cargo run --release -- --port 2000 --model-id <MODEL_ID> <MODEL_TYPE> --repeat-last-n 64
# MODEL_ID is the huggingface model id or local weight path
# MODEL_TYPE is one of ["llama", "llama3", "mistral", "phi2", "phi3", "qwen2", "gemma", "yi", "stable-lm"]

# Then run this file: python3 examples/benchmark.py --batch 16

openai.api_key = "EMPTY"

openai.base_url = "http://localhost:2000/v1/"

# You may add your custom prompts here
PROMPT_CANDIDATES = ["Explain how to best learn Rust.",
"Please talk about deep learning.",
"Do you know the capital city of China? Talk the details of you known.",
"Who is the best female actor in the world? Explain why.",
"Let me know how to deal with depression?",
"How to make money in short time?",
"What is the future trend of large language model?",
"The famous tech companies in the world."]

async def chat_completion(model, max_tokens, prompt):
completion = openai.chat.completions.create(
model=model,
@@ -34,30 +45,16 @@ async def stream_response(response_idx, stream: Stream[ChatCompletionChunk]):
result += r
return (response_idx, result)

async def benchmark():
model = "mistral7b"
max_tokens = 1024
# 16 requests
prompts = ["Explain how to best learn Rust.",
"Please talk about deep learning.",
"Do you know the capital city of China? Talk the details of you known.",
"Who is the best female actor in the world? Explain why.",
"Let me know how to deal with depression?",
"How to make money in short time?",
"What is the future trend of large language model?",
"The famous tech companies in the world.",
"Explain how to best learn Rust.",
"Please talk about deep learning.",
"Do you know the capital city of China? Talk the details of you known.",
"Who is the best female actor in the world? Explain why.",
"Let me know how to deal with depression?",
"How to make money in short time?",
"What is the future trend of large language model?",
"The famous tech companies in the world."]
async def benchmark(batch, max_tokens=1024):
model = "any" # model used dependent on the server side
# candidate requests
prompts = []
for i in range(batch):
prompts.append(PROMPT_CANDIDATES[i % len(PROMPT_CANDIDATES)])

# avoid generating very short answers
for i in range(len(prompts)):
prompts[i] = prompts[i] + " Respond in more than {} words.".format((int(max_tokens / 10) + 1) * 10)
prompts[i] = prompts[i] + " Respond in more than {} words.".format(int(max_tokens / 10) * 10)
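# e.g., with the default max_tokens=1024 this appends "Respond in more than 1020 words."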

# send `batch` chat requests at the same time
tasks: List[asyncio.Task] = []
@@ -86,4 +83,9 @@ async def benchmark():
print("\n\n Response {}: \n\n {}".format(idx, output))


asyncio.run(benchmark())
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark candle-vllm with configurable 'batch' and 'max_tokens' parameters.")
parser.add_argument('--batch', default=16, type=int)
parser.add_argument('--max_tokens', default=1024, type=int)
args = parser.parse_args()
asyncio.run(benchmark(args.batch, args.max_tokens))
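
The hunk above elides the middle of `benchmark()`, where the requests are actually dispatched. A minimal sketch of that pattern, assuming it still follows the original script: create all `chat_completion` tasks up front, then drain the response streams concurrently. `chat_completion` and `stream_response` are the helpers defined earlier in this file; `run_batch` is only an illustrative name.

``` python
import asyncio

async def run_batch(prompts, model="any", max_tokens=1024):
    # Fire one chat request per prompt without waiting for the previous one.
    request_tasks = [
        asyncio.create_task(chat_completion(model, max_tokens, p)) for p in prompts
    ]
    streams = await asyncio.gather(*request_tasks)

    # Drain the streamed responses concurrently, keeping each response's index.
    stream_tasks = [
        asyncio.create_task(stream_response(i, s)) for i, s in enumerate(streams)
    ]
    for idx, output in await asyncio.gather(*stream_tasks):
        print("\n\n Response {}: \n\n {}".format(idx, output))
```
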
2 changes: 2 additions & 0 deletions kernels/src/ffi.rs
@@ -40,6 +40,7 @@ extern "C" {
kv_head_stride: c_int,

dtype: u32,
softscapping: f32,
);

pub fn paged_attention_v2(
@@ -66,5 +67,6 @@ extern "C" {
kv_head_stride: c_int,

dtype: u32,
softscapping: f32,
);
}
2 changes: 1 addition & 1 deletion kernels/src/lib.rs
@@ -3,4 +3,4 @@ pub const COPY_BLOCKS_KERNEL: &str =
pub const PAGEDATTENTION: &str = include_str!(concat!(env!("OUT_DIR"), "/pagedattention.ptx"));
pub const RESHAPE_AND_CACHE_KERNEL: &str =
include_str!(concat!(env!("OUT_DIR"), "/reshape_and_cache_kernel.ptx"));
pub mod ffi;
pub mod ffi;
56 changes: 42 additions & 14 deletions kernels/src/pagedattention.cu
@@ -73,6 +73,20 @@ inline __device__ float block_sum(float* red_smem, float sum) {
return VLLM_SHFL_SYNC(sum, 0);
}

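// Fast device-side tanh for the soft-capping path: uses the PTX tanh.approx.f32
// instruction on sm_75+ when compiled with CUDA 11+, otherwise falls back to
// tanhf (device) or std::tanh (host).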
inline __device__ float fast_tanh(float x) {
#if defined(__CUDA_ARCH__)
#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750)
float y;
asm volatile ( "tanh.approx.f32 %0, %1; " : "=f"(y) : "f"(x));
return y;
#else
return ::tanhf(x);
#endif
#else
return std::tanh(x);
#endif
}

// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template<
@@ -96,7 +110,8 @@ __device__ void paged_attention_kernel(
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride) {
const int kv_head_stride,
const float softscapping) {
const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z;
@@ -212,6 +227,10 @@ __device__ void paged_attention_kernel(
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);

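// Attention-logit soft-capping: squash qk into (-softscapping, softscapping) with
// tanh. A softscapping value of 1.0 is treated as "capping disabled".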
if (softscapping != 1.0) {
qk = fast_tanh(qk / softscapping) * softscapping;
}
// Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;

@@ -409,11 +428,12 @@ __global__ void paged_attention_v1_kernel(
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride) {
const int kv_head_stride,
const float softscapping) {
paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
/* exp_sums */ nullptr, /* max_logits */ nullptr,
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, softscapping);
}

// Grid: (num_heads, num_seqs, max_num_partitions).
@@ -438,11 +458,12 @@ __global__ void paged_attention_v2_kernel(
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride) {
const int kv_head_stride,
const float softscapping) {
paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
q_stride, kv_block_stride, kv_head_stride);
q_stride, kv_block_stride, kv_head_stride, softscapping);
}

// Grid: (num_heads, num_seqs).
@@ -564,7 +585,8 @@ __global__ void paged_attention_v2_reduce_kernel(
alibi_slopes_ptr, \
q_stride, \
kv_block_stride, \
kv_head_stride);
kv_head_stride,\
softscapping);

// TODO(woosuk): Tune NUM_THREADS.
template<
@@ -588,7 +610,8 @@
int max_num_blocks_per_seq,
int q_stride,
int kv_block_stride,
int kv_head_stride
int kv_head_stride,
float softscapping
) {

// int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
@@ -652,7 +675,8 @@ void paged_attention_v1_launcher(
max_num_blocks_per_seq, \
q_stride, \
kv_block_stride, \
kv_head_stride);
kv_head_stride, \
softscapping);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
@@ -691,7 +715,8 @@ extern "C" void paged_attention_v1(
int32_t kv_block_stride,
int32_t kv_head_stride,

uint32_t dtype // 0 => f16; 1 => bf16; 2 => f32
uint32_t dtype, // 0 => f16; 1 => bf16; 2 => f32
float softscapping
) {
if (dtype == 2) {
CALL_V1_LAUNCHER_BLOCK_SIZE(float);
@@ -719,7 +744,8 @@
alibi_slopes, \
q_stride, \
kv_block_stride, \
kv_head_stride); \
kv_head_stride,\
softscapping); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
reinterpret_cast<T*>(out), \
@@ -754,8 +780,8 @@
int max_num_blocks_per_seq,
int q_stride,
int kv_block_stride,
int kv_head_stride

int kv_head_stride,
float softscapping
) {
// int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);

Expand Down Expand Up @@ -825,7 +851,8 @@ void paged_attention_v2_launcher(
max_num_blocks_per_seq, \
q_stride, \
kv_block_stride, \
kv_head_stride);
kv_head_stride,\
softscapping);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
@@ -867,7 +894,8 @@ extern "C" void paged_attention_v2(
int32_t kv_block_stride,
int32_t kv_head_stride,

uint32_t dtype // 0 => f16; 1 => bf16; 2 => f32
uint32_t dtype, // 0 => f16; 1 => bf16; 2 => f32
float softscapping
) {
if (dtype == 2) {
CALL_V2_LAUNCHER_BLOCK_SIZE(float);
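
Taken together, the CUDA changes thread the new `softscapping` parameter through both paged-attention kernels and apply `qk = tanh(qk / softscapping) * softscapping` to each attention logit before the ALiBi bias, skipping the transform when the value is 1.0. A small host-side Python reference of the same transform, which may be handy for sanity-checking kernel output (the `softcap` helper below is illustrative and not part of the repository):

``` python
import math

def softcap(qk: float, softscapping: float) -> float:
    # Mirrors the kernel: bound the attention logit to (-softscapping, softscapping)
    # via tanh; softscapping == 1.0 means "no capping".
    if softscapping == 1.0:
        return qk
    return math.tanh(qk / softscapping) * softscapping

# Example: with softscapping = 50.0, a raw logit of 120.0 is squashed to ~49.2,
# keeping the softmax inputs bounded.
print(softcap(120.0, 50.0))
```
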
6 changes: 5 additions & 1 deletion src/backend/paged_attention.rs
@@ -11,7 +11,7 @@ use std::ffi::c_int;

struct PagedAttention {
softmax_scale: f32,

softcapping: f32,
key_cache: Tensor,
value_cache: Tensor,
block_tables: Tensor,
@@ -187,6 +187,7 @@ impl PagedAttention {
kv_block_stride as c_int,
kv_head_stride as c_int,
internal_type,
self.softcapping,
)
}
} else {
@@ -223,6 +224,7 @@ impl PagedAttention {
kv_block_stride as c_int,
kv_head_stride as c_int,
internal_type,
self.softcapping,
)
}
}
@@ -277,6 +279,7 @@ pub fn paged_attention(
context_lens: &Tensor,
max_context_len: usize,
softmax_scale: f32,
softcapping: f32,
) -> Result<Tensor> {
let op = PagedAttention {
softmax_scale,
@@ -285,6 +288,7 @@
block_tables: block_tables.clone(),
context_lens: context_lens.clone(),
max_context_len,
softcapping,
};
q.apply_op1(op)
}