
Commit

Merge branch 'main' of https://github.com/huggingface/candle into CommandEncoderReuse
tomsanbear committed Apr 11, 2024
2 parents 79647c8 + a0460cd commit 912755d
Showing 2 changed files with 27 additions and 4 deletions.
12 changes: 12 additions & 0 deletions candle-examples/examples/gemma/main.rs
@@ -30,6 +30,14 @@ enum Which {
     InstructV1_1_2B,
     #[value(name = "1.1-7b-it")]
     InstructV1_1_7B,
+    #[value(name = "code-2b")]
+    CodeBase2B,
+    #[value(name = "code-7b")]
+    CodeBase7B,
+    #[value(name = "code-2b-it")]
+    CodeInstruct2B,
+    #[value(name = "code-7b-it")]
+    CodeInstruct7B,
 }
 
 struct TextGeneration {
@@ -224,6 +232,10 @@ fn main() -> Result<()> {
             Which::Base7B => "google/gemma-7b".to_string(),
             Which::Instruct2B => "google/gemma-2b-it".to_string(),
             Which::Instruct7B => "google/gemma-7b-it".to_string(),
+            Which::CodeBase2B => "google/codegemma-2b".to_string(),
+            Which::CodeBase7B => "google/codegemma-7b".to_string(),
+            Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
+            Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
         },
     };
     let repo = api.repo(Repo::with_revision(
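
Usage note (not part of the diff): the new variants are exposed through the clap value names added above. Assuming the example's usual --which and --prompt flags, a run against the instruction-tuned 2B code model might look like:

cargo run --example gemma --release -- --which code-2b-it --prompt "Write a Rust function that checks whether a string is a palindrome."
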
19 changes: 15 additions & 4 deletions candle-transformers/src/models/gemma.rs
@@ -1,7 +1,7 @@
 use std::sync::Arc;
 
 use candle::{DType, Device, Module, Result, Tensor, D};
-use candle_nn::{linear_b as linear, Linear, VarBuilder};
+use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder};
 
 fn default_max_position_embeddings() -> usize {
     4096
@@ -11,8 +11,9 @@ fn default_max_position_embeddings() -> usize {
 pub struct Config {
     pub attention_bias: bool,
     pub head_dim: usize,
-    #[serde(alias = "hidden_activation")]
-    pub hidden_act: candle_nn::Activation,
+    // The code gemma configs include both hidden_act and hidden_activation.
+    pub hidden_act: Option<Activation>,
+    pub hidden_activation: Option<Activation>,
     pub hidden_size: usize,
     pub intermediate_size: usize,
     pub num_attention_heads: usize,
@@ -26,6 +27,16 @@ pub struct Config {
     pub max_position_embeddings: usize,
 }
 
+impl Config {
+    fn hidden_act(&self) -> Result<Activation> {
+        match (self.hidden_act, self.hidden_activation) {
+            (None, Some(act)) | (Some(act), None) => Ok(act),
+            (Some(_), Some(_)) => candle::bail!("both hidden_act and hidden_activation are set"),
+            (None, None) => candle::bail!("none of hidden_act and hidden_activation are set"),
+        }
+    }
+}
+
 #[derive(Debug, Clone)]
 struct RmsNorm {
     weight: Tensor,
@@ -127,7 +138,7 @@ impl MLP {
             gate_proj,
             up_proj,
             down_proj,
-            act_fn: cfg.hidden_act,
+            act_fn: cfg.hidden_act()?,
         })
     }
 }
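
Side note on the config change above (not part of this commit): with the old #[serde(alias = "hidden_activation")] attribute, a config JSON that carries both keys, as the comment in the diff says the CodeGemma configs do, fails to deserialize outright with an opaque duplicate-field error from serde. Splitting the keys into two Option fields lets deserialization succeed and moves the decision into Config::hidden_act(), which can report a clearer error or pick the single value that is present. A minimal standalone sketch of that serde behaviour, using plain String fields in place of candle_nn::Activation and requiring only serde (with the derive feature) and serde_json:

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct AliasConfig {
    // Old approach: a single field that also answers to "hidden_activation".
    #[serde(alias = "hidden_activation")]
    hidden_act: String,
}

#[derive(Debug, Deserialize)]
struct OptionConfig {
    // New approach: both keys are captured and reconciled afterwards.
    hidden_act: Option<String>,
    hidden_activation: Option<String>,
}

fn main() {
    let json = r#"{ "hidden_act": "gelu", "hidden_activation": "gelu_pytorch_tanh" }"#;

    // The alias makes both keys target the same field, so serde rejects the
    // document with a "duplicate field" error.
    assert!(serde_json::from_str::<AliasConfig>(json).is_err());

    // Two Option fields accept the same document; a helper like Config::hidden_act()
    // in the diff above then decides what to do with the pair.
    let cfg: OptionConfig = serde_json::from_str(json).unwrap();
    assert_eq!(cfg.hidden_act.as_deref(), Some("gelu"));
    assert_eq!(cfg.hidden_activation.as_deref(), Some("gelu_pytorch_tanh"));
}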
