
Add model support for gemma 9b #79

Open
sigridjineth opened this issue Aug 14, 2024 · 16 comments

Assignees
Labels
enhancement New feature or request

Comments

@sigridjineth

Symptoms

I found that using google/gemma-9b-it raises an error stating the following:

(Some(_), Some(_)) => panic!("both hidden_act and hidden_activation are set"),
@sigridjineth sigridjineth changed the title Add model support for gemma 9b. Add model support for gemma 9b Aug 14, 2024
@guoqingbao
Collaborator

Symptoms

I found that using google/gemma-9b-it raises an error stating the following:

(Some(_), Some(_)) => panic!("both hidden_act and hidden_activation are set"),

Normally, only one of hidden_act or hidden_activation is set in the config file. I've revised the code to prioritize hidden_act when both are set, instead of panicking. This should resolve the issue; please refer to PR #82.
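
For illustration, a minimal sketch of the kind of change described above (preferring hidden_act over hidden_activation instead of panicking); the function name and the use of candle_nn::Activation are assumptions here, not candle-vllm's actual code:

// Illustrative only: resolve the model activation when a config may set
// `hidden_act`, `hidden_activation`, or both. Prefer `hidden_act` if both are present.
use candle_nn::Activation;

fn resolve_activation(
    hidden_act: Option<Activation>,
    hidden_activation: Option<Activation>,
) -> Activation {
    match (hidden_act, hidden_activation) {
        // Previously this case panicked with "both hidden_act and hidden_activation are set".
        (Some(act), _) => act,
        (None, Some(act)) => act,
        (None, None) => panic!("neither hidden_act nor hidden_activation is set"),
    }
}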

@sigridjineth
Author

@guoqingbao thanks for checking this, but I am still getting the following error.

root@3359be8d598e:/workspace/sigrid/candle-vllm# cargo run --release -- --port 2000 --model-id rtzr/ko-gemma-2-9b-it gemma
warning: value assigned to `group_idx` is never read
   --> src/openai/pipelines/pipeline.rs:453:21
    |
453 |             let mut group_idx = 0;
    |                     ^^^^^^^^^
    |
    = help: maybe it is overwritten before being read?
    = note: `#[warn(unused_assignments)]` on by default

warning: field `config` is never read
  --> src/openai/pipelines/pipeline.rs:65:5
   |
56 | pub struct DefaultPipeline {
   |            --------------- field in this struct
...
65 |     config: Config,
   |     ^^^^^^
   |
   = note: `#[warn(dead_code)]` on by default

warning: `candle-vllm` (lib) generated 2 warnings
    Finished `release` profile [optimized] target(s) in 0.21s
     Running `target/release/candle-vllm --port 2000 --model-id rtzr/ko-gemma-2-9b-it gemma`
both hidden_act and hidden_activation are set
Model Config { hidden_size: 3584, intermediate_size: 14336, vocab_size: 256000, num_hidden_layers: 42, num_attention_heads: 16, num_key_value_heads: 8, use_flash_attn: false, rms_norm_eps: 1e-6, rope_theta: 10000.0, bos_token_id: TokenID(Left(Some(2))), eos_token_id: TokenID(Left(Some(1))), max_seq_len: 8192, sliding_window: None, hidden_act: Some(GeluPytorchTanh), tie_word_embeddings: false, rope_scaling: None, original_max_position_embeddings: None, attention_bias: false, partial_rotary_factor: None, qk_layer_rms_norm: None, kv_cache_dtype: BF16, use_qkv_bias: None, custom_stop_tokens: None, specific_config: SpecificConfig { repeat_last_n: None, temperature: None, top_k: None, top_p: None, penalty: None, max_gen_tokens: None, quant: None } }
Loading gemma model.
Error: APIError { data: "shape mismatch for model.layers.0.self_attn.q_proj.weight, expected: [3584, 3584], got: [4096, 3584]" }
root@3359be8d598e:/workspace/sigrid/candle-vllm#

@guoqingbao
Collaborator

@guoqingbao thanks for checking this, but I am still getting the following error.

Error: APIError { data: "shape mismatch for model.layers.0.self_attn.q_proj.weight, expected: [3584, 3584], got: [4096, 3584]" }

I have checked the structure of Gemma-2 and found that it is different from the Gemma architecture we currently support. We will add support for Gemma-2 in a later update, e.g., by porting candle's Gemma-2 implementation to candle-vllm (huggingface/candle@c1b9e07).
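
For reference, a quick arithmetic check of the shape mismatch above, assuming Gemma-2-9b's published head_dim of 256 (a value not shown in the printed config):

// Why the current Gemma loader fails on Gemma-2-9b: q_proj rows come from
// num_attention_heads * head_dim, which no longer equals hidden_size.
fn main() {
    let hidden_size = 3584usize;          // from the Config printed above
    let num_attention_heads = 16usize;    // from the Config printed above
    let head_dim = 256usize;              // Gemma-2-9b value; assumed here, not in the log
    assert_eq!(num_attention_heads * head_dim, 4096); // matches "got: [4096, 3584]"
    assert_ne!(num_attention_heads * head_dim, hidden_size); // hence the shape mismatch
}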

@guoqingbao guoqingbao added the enhancement New feature or request label Aug 21, 2024
@guoqingbao guoqingbao self-assigned this Aug 21, 2024
@EricLBuehler EricLBuehler self-assigned this Aug 21, 2024
@EricLBuehler
Owner

EricLBuehler commented Aug 21, 2024

@guoqingbao @sigridjineth we can perhaps base this work on EricLBuehler/mistral.rs#490 and EricLBuehler/mistral.rs#554.

@guoqingbao
Collaborator

@guoqingbao @sigridjineth we can perhaps base this work on EricLBuehler/mistral.rs#490 and EricLBuehler/mistral.rs#554.

Sure!

@EricLBuehler
Owner

@guoqingbao it looks like we will need to update PA kernels for softcapping though.

@guoqingbao
Collaborator

@guoqingbao it looks like we will need to update PA kernels for softcapping though.

Have you encountered problems with gemma-2 in Mistral.rs using the current PA kernels?

@EricLBuehler
Owner

Have you encountered problems with gemma-2 in Mistral.rs using the current PA kernels?

I actually disabled PA for that case, so I haven't looked into it but we can add it here!

@guoqingbao
Collaborator

Have you encountered problems with gemma-2 in Mistral.rs using the current PA kernels?

I actually disabled PA for that case, so I haven't looked into it but we can add it here!

I found the introduction of soft-capping in Google's Gemma-2 release; it seems that soft-capping is designed for training, and I'm not sure if it is necessary for inference:

https://huggingface.co/blog/gemma2

"Soft capping is a technique that prevents logits from growing excessively large without truncating them. It works by dividing the logits by a maximum value threshold (soft_cap), then passing them through a tanh layer (ensuring they are in the (-1, 1) range), and finally multiplying by the threshold again. This guarantees that the final values will be in the (-soft_cap, +soft_cap) interval without losing much information but stabilizing the training."

@EricLBuehler
Owner

I found the introduction of soft-capping in Google's Gemma-2 release; it seems that soft-capping is designed for training, and I'm not sure if it is necessary for inference.

Yeah, I agree. This would probably be fine to exclude; it looks like vLLM does the same and doesn't even use the sliding-window interleaving (they use global attention everywhere). I'll add the PR shortly!

@guoqingbao
Collaborator

Yeah, I agree. This would probably be fine to exclude; it looks like vLLM does the same and doesn't even use the sliding-window interleaving (they use global attention everywhere). I'll add the PR shortly!

Yes, I also checked the newest version of vLLM and didn't find an implementation of soft-capping for attention. It seems the feature (soft-capping) is only available in JAX for training.

@guoqingbao
Collaborator

@sigridjineth As discussed in #84, Gemma-2 models are now supported; please refer to #86. Please feel free to report any other issues.

@sigridjineth
Author

@guoqingbao Thank you very much!

@sigridjineth
Author

@guoqingbao I found that gemma-2-2b is not supported at this moment. I will work on that in my spare time. Thanks for your attention!

root@3359be8d598e:/workspace/sigrid/candle-vllm# cargo run --release -- --port 2000 --model-id google/gemma-2-2b-it gemma
warning: field `hidden_size` is never read
   --> src/openai/models/gemma.rs:201:5
    |
192 | struct Attention {
    |        --------- field in this struct
...
201 |     hidden_size: usize,
    |     ^^^^^^^^^^^
    |
    = note: `#[warn(dead_code)]` on by default

warning: `candle-vllm` (lib) generated 1 warning
    Finished `release` profile [optimized] target(s) in 0.21s
     Running `target/release/candle-vllm --port 2000 --model-id google/gemma-2-2b-it gemma`
tokenizer.json [00:00:00] [█████████████████████████████████████████] 16.71 MiB/16.71 MiB 26.86 MiB/s (0s)
config.json [00:00:00] [█████████████████████████████████████████████████████] 838 B/838 B 4.16 KiB/s (0s)
..del-00001-of-00002.safetensors [00:00:58] [█████████████████████████] 4.65 GiB/4.65 GiB 81.62 MiB/s (0s)
..del-00002-of-00002.safetensors [00:00:03] [█████████████████████] 229.54 MiB/229.54 MiB 69.68 MiB/s (0s)Error: APIError { data: "invalid type: sequence, expected usize at line 10 column 18" }

@EricLBuehler
Owner

@sigridjineth this is because the EOS token IDs are an array for the 2b model; we have a struct for this which implements Deserialize if you want to add a PR!
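
For anyone picking this up, a self-contained sketch of that approach, accepting eos_token_id as either a single id or a list via serde's untagged enum; the type names and example values below are illustrative, not candle-vllm's actual struct:

use serde::Deserialize;

// Accept `eos_token_id` as either a single id or a list of ids
// (the list form is what trips the current parser for gemma-2-2b).
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum EosTokenId {
    Single(usize),
    Multiple(Vec<usize>),
}

#[derive(Debug, Deserialize)]
struct PartialConfig {
    eos_token_id: EosTokenId,
}

fn main() -> serde_json::Result<()> {
    // list form (example values)
    let multi: PartialConfig = serde_json::from_str(r#"{ "eos_token_id": [1, 107] }"#)?;
    // single-id form used by most other configs
    let single: PartialConfig = serde_json::from_str(r#"{ "eos_token_id": 1 }"#)?;
    println!("{multi:?} {single:?}");
    Ok(())
}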

@guoqingbao
Collaborator

@guoqingbao I found that gemma-2-2b is not supported at this moment. I will work on that in my spare time. Thanks for your attention!

Error: APIError { data: "invalid type: sequence, expected usize at line 10 column 18" }

This PR should address the problem: 124fadc
