Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Update gpt2 to use wte if no lm_head #362

Merged
merged 3 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions crates/ggml/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,11 @@ impl Context {

/// Creates a 1D view over `a`.
pub fn op_view_1d(&self, a: &Tensor, ne0: usize, offset: usize) -> Tensor {
#[cfg(debug_assertions)]
assert!(
offset < a.nbytes(),
"Cannot create tensor view with offset larger than tensor"
);
let tensor = unsafe {
sys::ggml_view_1d(self.ptr.as_ptr(), a.ptr.as_ptr(), usize_to_i64(ne0), offset)
};
Expand Down
14 changes: 10 additions & 4 deletions crates/models/gpt2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ pub struct Gpt2 {
// weighted positional encodings
wpe: Tensor,
// language model head
lm_head: Tensor,
//
// Optional: if not present, the `wte` tensor is used instead.
lm_head: Option<Tensor>,

// weights for the model
layers: Vec<Layer>,
Expand Down Expand Up @@ -59,7 +61,10 @@ impl KnownModel for Gpt2 {
let ln_f_b = tl.load("model/ln_f/b")?;
let wte = tl.load("model/wte")?;
let wpe = tl.load("model/wpe")?;
let lm_head = tl.load("model/lm_head")?;

// GPT-2's language model head is optional; if it is not present,
// the `wte` tensor is used instead.
let lm_head = tl.load("model/lm_head").ok();

let mut layers = Vec::new();
for i in 0..hyperparameters.n_layer {
Expand Down Expand Up @@ -102,7 +107,7 @@ impl KnownModel for Gpt2 {
fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
InferenceSession::new(
config,
self.hyperparameters.n_ctx,
self.context_size,
self.hyperparameters.n_layer,
self.hyperparameters.n_embd,
self.hyperparameters.n_vocab,
Expand Down Expand Up @@ -306,7 +311,8 @@ impl KnownModel for Gpt2 {

let embeddings_tensor: ggml::Tensor = input_layer.share();

input_layer = ctx0.op_mul_mat(&self.lm_head, &input_layer);
let head = self.lm_head.as_ref().unwrap_or(&self.wte);
input_layer = ctx0.op_mul_mat(head, &input_layer);

(
gf,
Expand Down
2 changes: 1 addition & 1 deletion crates/models/gptj/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl KnownModel for GptJ {
fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
InferenceSession::new(
config,
self.hyperparameters.n_ctx,
self.context_size,
self.hyperparameters.n_layer,
self.hyperparameters.n_embd,
self.hyperparameters.n_vocab,
Expand Down
2 changes: 1 addition & 1 deletion crates/models/gptneox/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ impl KnownModel for GptNeoX {
fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
InferenceSession::new(
config,
self.hyperparameters.n_ctx,
self.context_size,
self.hyperparameters.n_layer,
self.hyperparameters.n_embd,
self.hyperparameters.n_vocab,
Expand Down