
Commit

Fix Gemma-2 multiple eos/bos ids (#87)
guoqingbao authored Aug 22, 2024
1 parent c170b23 commit 124fadc
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions src/openai/models/gemma.rs
@@ -2,16 +2,17 @@ use super::Config;
use crate::openai::models::linear::{
linear_b_x as linear_b, linear_no_bias_x as linear, LinearX as Linear,
};
+use crate::openai::models::TokenID;
use crate::paged_attention::input_metadata::InputMetadata;
use crate::paged_attention::PagedAttention;
use crate::SpecificConfig;
use candle::{DType, Device, IndexOp, Module, Result, Tensor};
use candle_core as candle;
use candle_nn::Activation;
use candle_nn::{RmsNorm, VarBuilder};
use either::Either;
use std::iter::zip;
use std::sync::Arc;

#[derive(serde::Deserialize, Debug, Clone)]
pub struct GemmaConfig {
pub attention_bias: bool,
@@ -27,8 +28,8 @@ pub struct GemmaConfig {
pub rms_norm_eps: f64,
pub rope_theta: f64,
pub vocab_size: usize,
-pub bos_token_id: usize,
-pub eos_token_id: usize,
+pub bos_token_id: TokenID,
+pub eos_token_id: TokenID,
pub max_position_embeddings: Option<usize>,
pub attn_logit_softcapping: Option<f64>,
pub final_logit_softcapping: Option<f64>,
@@ -63,8 +64,8 @@ impl GemmaConfig {
rms_norm_eps: self.rms_norm_eps,
rope_theta: self.rope_theta,
use_flash_attn,
-bos_token_id: super::TokenID(Either::Left(Some(self.bos_token_id as u32))),
-eos_token_id: super::TokenID(Either::Left(Some(self.eos_token_id as u32))),
+bos_token_id: self.bos_token_id,
+eos_token_id: self.eos_token_id,
max_seq_len: self.max_position_embeddings.unwrap_or(4096),
sliding_window: None,
hidden_act,
@@ -198,7 +199,6 @@ struct Attention {
num_kv_heads: usize,
head_dim: usize,
rotary_emb: Arc<RotaryEmbedding>,
-hidden_size: usize,
attn: PagedAttention,
}

@@ -246,7 +246,6 @@ impl Attention {
num_kv_heads,
head_dim,
rotary_emb,
-hidden_size: hidden_sz,
attn: PagedAttention::new(
cfg.num_attention_heads,
head_dim,
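Why this change works: the shared TokenID type (already used by the old code via super::TokenID(Either::Left(Some(...)))) can represent either a single token id or a list of ids, so Gemma-2 configs, which ship multiple eos ids, now deserialize directly into GemmaConfig instead of failing on the old usize fields. The snippet below is a minimal sketch, not the crate's actual definition: it assumes TokenID wraps either::Either with the Right variant holding a Vec<u32>, and the GemmaConfigSketch struct is purely illustrative.

// Sketch of a TokenID-style newtype. Left(Option<u32>) matches the old
// single-id path visible in the diff; Right(Vec<u32>) is an assumed variant
// for multiple ids. Requires the "serde" feature of the either crate.
use either::Either;
use serde::Deserialize;

#[derive(Debug, Clone, Deserialize)]
pub struct TokenID(
    #[serde(with = "either::serde_untagged")] pub Either<Option<u32>, Vec<u32>>,
);

// Illustrative config fragment; only the two token-id fields are shown.
#[derive(Debug, Deserialize)]
pub struct GemmaConfigSketch {
    pub bos_token_id: TokenID,
    pub eos_token_id: TokenID,
}

fn main() -> serde_json::Result<()> {
    // Gemma-1 style config: single ids land in the Left variant.
    let v1: GemmaConfigSketch =
        serde_json::from_str(r#"{ "bos_token_id": 2, "eos_token_id": 1 }"#)?;
    // Gemma-2 style config: a list of eos ids, which a plain usize field rejects.
    let v2: GemmaConfigSketch =
        serde_json::from_str(r#"{ "bos_token_id": 2, "eos_token_id": [1, 107] }"#)?;
    println!("{:?}\n{:?}", v1, v2);
    Ok(())
}

With both fields typed as TokenID, GemmaConfig::into_config can pass self.bos_token_id and self.eos_token_id through unchanged rather than wrapping a single usize in Either::Left, which is exactly what the hunks above do.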
