
Support vLLM-style rope #530

Closed

ByronHsu opened this issue Oct 15, 2024 · 5 comments

@ByronHsu commented Oct 15, 2024

As part of SGLang Issue #1487, SGLang plans to move vLLM to optional dependencies and use flashinfer as the main dependency.

I am working on moving rope to flashinfer. My plan is to reuse most of the existing vLLM rope module but replace ops.rotary_embedding and ops.batch_rotary_embedding with flashinfer's kernels.

However, I've noticed some gaps between the vLLM and flashinfer implementations:

  1. Cos_sin_cache: vLLM pre-computes the cos_sin_cache in the constructor, whereas flashinfer computes it on the fly.
  2. Offsets and indptr instead of positions: It's tricky to convert positions back to offsets + indptr. Can we support positions directly?
  3. Partial rotate: vLLM supports a partial rotation where the rotary dimension is less than the head dimension.
  4. Batched rope for multi-Lora: For more context, see vLLM pull request #3095.

In general, we can prioritize the first three issues and consider the fourth as a stretch goal.
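
For point 1, here is a rough sketch of the vLLM-style precomputation being referred to (names and defaults here are illustrative, not flashinfer's actual API):

```python
import torch

def build_cos_sin_cache(rotary_dim: int, max_position: int,
                        base: float = 10000.0,
                        dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Frequency for each rotated pair of dimensions: base^(-2i / rotary_dim).
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    # Outer product of positions and frequencies -> [max_position, rotary_dim // 2].
    t = torch.arange(max_position, dtype=torch.float32)
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    # cos and sin concatenated along the last dim -> [max_position, rotary_dim].
    return torch.cat([freqs.cos(), freqs.sin()], dim=-1).to(dtype)

# Precomputed once in the layer constructor, then indexed by `positions` at runtime.
cos_sin_cache = build_cos_sin_cache(rotary_dim=128, max_position=4096)
```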

@yzh119 (Collaborator) commented Oct 15, 2024

Hi @ByronHsu, thanks for the suggestions. I think 1 & 3 are easy to support.

For 1, we can add sin_cache and cos_cache as optional fields to the rope APIs. For long context, there might be numerical issues with an f16 sin/cos cache, so we should also support an f32 sin/cos cache (our current on-the-fly sin/cos computation uses f32).
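
A quick toy check of how much precision an f16 cache gives up relative to f32 (a sketch of my own, not the kernel code):

```python
import torch

rotary_dim, base, max_position = 128, 10000.0, 131072
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
freqs = torch.arange(max_position, dtype=torch.float32)[:, None] * inv_freq[None, :]

cos_f32 = freqs.cos()
# Round-trip through f16 to measure the quantization error an f16 cache would carry.
cos_f16 = cos_f32.to(torch.float16).to(torch.float32)
print("max abs error of f16 cos cache:", (cos_f32 - cos_f16).abs().max().item())
```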

For 3, yes, we can add a rope_dim field for partial rope.

Can you give a concrete example of 2?

@yzh119 (Collaborator) commented Oct 15, 2024

Okay, I think I understand 2 now. For example, if batch_size=3, indptr=[0, 1, 5, 10], and offsets=[4, 6, 3], then the equivalent positions would be:
[4, 6, 7, 8, 9, 3, 4, 5, 6, 7].

Is that true?

@ByronHsu (Author)

> Okay, I think I understand 2 now. For example, if batch_size=3, indptr=[0, 1, 5, 10], and offsets=[4, 6, 3], then the equivalent positions would be:
> [4, 6, 7, 8, 9, 3, 4, 5, 6, 7].

Yes, exactly! Thank you for the prompt response! All sounds good to me.
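
For concreteness, that conversion can be written as (a quick sketch, not flashinfer code):

```python
import torch

def positions_from_indptr_offsets(indptr: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    # Tokens indptr[i]..indptr[i+1] of request i get positions offsets[i], offsets[i]+1, ...
    lengths = indptr[1:] - indptr[:-1]
    starts = torch.repeat_interleave(offsets, lengths)
    ramp = torch.arange(indptr[-1].item()) - torch.repeat_interleave(indptr[:-1], lengths)
    return starts + ramp

indptr = torch.tensor([0, 1, 5, 10])
offsets = torch.tensor([4, 6, 3])
print(positions_from_indptr_offsets(indptr, offsets))
# tensor([4, 6, 7, 8, 9, 3, 4, 5, 6, 7])
```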

One comment: can we keep the new API separate from flashinfer's current 4 rope functions and provide exactly the same interface as vLLM? A couple of reasons:

  1. apply_rope_inplace only implements the formula from the original paper, but in practice there are many more variants.
  2. It makes it easy for all of vLLM's kernel users to migrate.

Maybe we can call this apply_rope_inplace_with_cache, which does not compute cos/sin on the fly and supports the features proposed above:

```python
def apply_rope_inplace_with_cache(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    ...
```
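
Usage of the proposed function would then look the same as calling vLLM's ops.rotary_embedding; the shapes below are just my assumption of the flattened [num_tokens, num_heads * head_size] layout:

```python
import torch

num_tokens, num_qo_heads, num_kv_heads, head_size = 10, 32, 8, 128

positions = torch.arange(num_tokens, device="cuda")
query = torch.randn(num_tokens, num_qo_heads * head_size, dtype=torch.float16, device="cuda")
key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=torch.float16, device="cuda")
# Placeholder for a real precomputed fp32 cache of shape [max_position, rotary_dim].
cos_sin_cache = torch.randn(4096, head_size, dtype=torch.float32, device="cuda")

# Rotation is applied to `query` and `key` in place; nothing is returned.
apply_rope_inplace_with_cache(positions, query, key, head_size, cos_sin_cache, is_neox=True)
```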

@ByronHsu (Author)

I did a global search and found that ops.batch_rotary_embedding is not used in SGLang (and apparently not in vLLM either), so we can safely skip the 4th feature. Thanks!

yzh119 added a commit that referenced this issue Oct 29, 2024
Previously, our rope APIs assumed the position indices of each request are contiguous, which is not appropriate for applications such as speculative decoding. This PR fixes the issue by supporting a Hugging Face transformers-style API that uses a `pos_ids` argument to specify positions.

This PR implements part of the features requested in #530; the remaining requests are coming in later PRs.

cc @dreaming-panda @abcdabcd987 @ByronHsu
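
For anyone following along, a minimal sketch of what a pos_ids-style call enables; the function name and the [nnz, num_heads, head_dim] layout here are my assumptions, so check the actual flashinfer API:

```python
import torch
import flashinfer

nnz, num_qo_heads, num_kv_heads, head_dim = 7, 32, 8, 128
q = torch.randn(nnz, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")

# Positions need not be contiguous per request, e.g. speculative-decoding drafts
# where several candidate tokens continue from the same position.
pos_ids = torch.tensor([5, 6, 7, 7, 7, 8, 8], dtype=torch.int32, device="cuda")

q_rot, k_rot = flashinfer.apply_rope_pos_ids(q, k, pos_ids)  # assumed name/signature
```
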
yzh119 added a commit that referenced this issue Nov 5, 2024
As requested in #530, this PR implements RoPE with cached cos/sin embeddings, which is more flexible in some use cases.

In our previous RoPE implementations, cos/sin values are computed on the fly inside the kernels in float32 instead of using cached values.

In this PR we found that using an f16 cos/sin cache produces a large discrepancy in the rope results compared to our original implementation `flashinfer.apply_rope` (which computes cos/sin in fp32), so we require `cos_cache` and `sin_cache` to use the fp32 data type.

cc @dreaming-panda @ByronHsu
yzh119 added a commit that referenced this issue Nov 10, 2024 (#599)

This PR implements the final piece of #530, so that rotary embedding can be applied partially to the first dimensions of each head instead of the entire head dimension.
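
For reference, a plain-PyTorch sketch of the semantics (assuming the non-interleaved/Neox layout; not the kernel code): only the first rotary_dim entries of each head are rotated, and the rest pass through unchanged.

```python
import torch

def partial_rope_reference(x: torch.Tensor, pos: torch.Tensor, rotary_dim: int,
                           base: float = 10000.0) -> torch.Tensor:
    # x: [num_tokens, num_heads, head_dim]; only x[..., :rotary_dim] is rotated.
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    freqs = pos[:, None].float() * inv_freq[None, :]              # [num_tokens, rotary_dim // 2]
    cos, sin = freqs.cos()[:, None, :], freqs.sin()[:, None, :]   # broadcast over heads
    x1, x2 = x_rot[..., : rotary_dim // 2], x_rot[..., rotary_dim // 2 :]
    rotated = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
    return torch.cat([rotated.to(x.dtype), x_pass], dim=-1)

q = torch.randn(10, 32, 128)
q_out = partial_rope_reference(q, pos=torch.arange(10), rotary_dim=64)
```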

We also add a simple benchmark for RoPE; below are the results on an H100:
```
batch_size:   1, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 23us, throughput:   0.876GB/s
batch_size:   1, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 26us, throughput:   0.801GB/s
batch_size:   1, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 27us, throughput:  95.735GB/s
batch_size:   1, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 27us, throughput:  95.639GB/s
batch_size:   1, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 31us, throughput: 672.889GB/s
batch_size:   1, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 32us, throughput: 662.972GB/s
---
batch_size:  19, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 27us, throughput:  14.559GB/s
batch_size:  19, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 27us, throughput:  14.435GB/s
batch_size:  19, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 37us, throughput: 1339.450GB/s
batch_size:  19, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 37us, throughput: 1340.399GB/s
batch_size:  19, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 148us, throughput: 2696.563GB/s
batch_size:  19, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 148us, throughput: 2689.104GB/s
---
batch_size:  99, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 27us, throughput:  74.186GB/s
batch_size:  99, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 27us, throughput:  74.452GB/s
batch_size:  99, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 110us, throughput: 2350.830GB/s
batch_size:  99, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 110us, throughput: 2359.814GB/s
batch_size:  99, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 717us, throughput: 2895.389GB/s
batch_size:  99, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 718us, throughput: 2891.385GB/s
---
batch_size: 128, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 27us, throughput:  95.449GB/s
batch_size: 128, append_len:     1, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 27us, throughput:  95.646GB/s
batch_size: 128, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 130us, throughput: 2576.101GB/s
batch_size: 128, append_len:   128, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 130us, throughput: 2582.447GB/s
batch_size: 128, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: False, latency: 924us, throughput: 2906.154GB/s
batch_size: 128, append_len:  1024, num_qo_heads:    32, num_kv_heads:     8, head_dim:   128, use_cos_sin_cache: True, latency: 925us, throughput: 2903.484GB/s
```
@yzh119 (Collaborator) commented Nov 10, 2024

Done in #568, #585 and #599.

@yzh119 yzh119 closed this as completed Nov 10, 2024