Make the RNN configs accessible from the models.
LaurentMazare committed Oct 4, 2024
1 parent 6faecaa commit d5cb827
Showing 3 changed files with 103 additions and 74 deletions.
1 change: 0 additions & 1 deletion candle-examples/examples/encodec/audio_io.rs
@@ -1,4 +1,3 @@
-#![allow(unused)]
 use anyhow::{Context, Result};
 use std::sync::{Arc, Mutex};
 
1 change: 0 additions & 1 deletion candle-examples/examples/mimi/audio_io.rs
@@ -1,4 +1,3 @@
-#![allow(unused)]
 use anyhow::{Context, Result};
 use std::sync::{Arc, Mutex};
 
175 changes: 103 additions & 72 deletions candle-nn/src/rnn.rs
@@ -116,7 +116,7 @@ impl LSTMConfig {
 /// A Long Short-Term Memory (LSTM) layer.
 ///
 /// <https://en.wikipedia.org/wiki/Long_short-term_memory>
-#[allow(clippy::upper_case_acronyms, unused)]
+#[allow(clippy::upper_case_acronyms)]
 #[derive(Clone, Debug)]
 pub struct LSTM {
     w_ih: Tensor,
@@ -129,54 +129,70 @@ pub struct LSTM {
     dtype: DType,
 }
 
+impl LSTM {
+    /// Creates a LSTM layer.
+    pub fn new(
+        in_dim: usize,
+        hidden_dim: usize,
+        config: LSTMConfig,
+        vb: crate::VarBuilder,
+    ) -> Result<Self> {
+        let layer_idx = config.layer_idx;
+        let direction_str = match config.direction {
+            Direction::Forward => "",
+            Direction::Backward => "_reverse",
+        };
+        let w_ih = vb.get_with_hints(
+            (4 * hidden_dim, in_dim),
+            &format!("weight_ih_l{layer_idx}{direction_str}"), // Only a single layer is supported.
+            config.w_ih_init,
+        )?;
+        let w_hh = vb.get_with_hints(
+            (4 * hidden_dim, hidden_dim),
+            &format!("weight_hh_l{layer_idx}{direction_str}"), // Only a single layer is supported.
+            config.w_hh_init,
+        )?;
+        let b_ih = match config.b_ih_init {
+            Some(init) => Some(vb.get_with_hints(
+                4 * hidden_dim,
+                &format!("bias_ih_l{layer_idx}{direction_str}"),
+                init,
+            )?),
+            None => None,
+        };
+        let b_hh = match config.b_hh_init {
+            Some(init) => Some(vb.get_with_hints(
+                4 * hidden_dim,
+                &format!("bias_hh_l{layer_idx}{direction_str}"),
+                init,
+            )?),
+            None => None,
+        };
+        Ok(Self {
+            w_ih,
+            w_hh,
+            b_ih,
+            b_hh,
+            hidden_dim,
+            config,
+            device: vb.device().clone(),
+            dtype: vb.dtype(),
+        })
+    }
+
+    pub fn config(&self) -> &LSTMConfig {
+        &self.config
+    }
+}
+
 /// Creates a LSTM layer.
 pub fn lstm(
     in_dim: usize,
     hidden_dim: usize,
     config: LSTMConfig,
     vb: crate::VarBuilder,
 ) -> Result<LSTM> {
-    let layer_idx = config.layer_idx;
-    let direction_str = match config.direction {
-        Direction::Forward => "",
-        Direction::Backward => "_reverse",
-    };
-    let w_ih = vb.get_with_hints(
-        (4 * hidden_dim, in_dim),
-        &format!("weight_ih_l{layer_idx}{direction_str}"), // Only a single layer is supported.
-        config.w_ih_init,
-    )?;
-    let w_hh = vb.get_with_hints(
-        (4 * hidden_dim, hidden_dim),
-        &format!("weight_hh_l{layer_idx}{direction_str}"), // Only a single layer is supported.
-        config.w_hh_init,
-    )?;
-    let b_ih = match config.b_ih_init {
-        Some(init) => Some(vb.get_with_hints(
-            4 * hidden_dim,
-            &format!("bias_ih_l{layer_idx}{direction_str}"),
-            init,
-        )?),
-        None => None,
-    };
-    let b_hh = match config.b_hh_init {
-        Some(init) => Some(vb.get_with_hints(
-            4 * hidden_dim,
-            &format!("bias_hh_l{layer_idx}{direction_str}"),
-            init,
-        )?),
-        None => None,
-    };
-    Ok(LSTM {
-        w_ih,
-        w_hh,
-        b_ih,
-        b_hh,
-        hidden_dim,
-        config,
-        device: vb.device().clone(),
-        dtype: vb.dtype(),
-    })
+    LSTM::new(in_dim, hidden_dim, config, vb)
 }
 
 impl RNN for LSTM {
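
For reference, a minimal usage sketch of the constructor this commit makes public. It is not part of the diff; it assumes the `candle` alias for `candle-core` used inside this repository and the `LSTM`, `LSTMConfig`, `VarBuilder`, and `VarMap` re-exports from `candle_nn`:

use candle::{DType, Device, Result};
use candle_nn::{LSTMConfig, VarBuilder, VarMap, LSTM};

fn main() -> Result<()> {
    // Fresh variables on the CPU, created with the config's init hints;
    // a real model would load pretrained weights through the same interface.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
    // `LSTM::new` is now public; the free function `lstm(...)` simply
    // forwards to it after this commit.
    let layer = LSTM::new(16, 32, LSTMConfig::default(), vb)?;
    // The config can now be read back from the model.
    assert_eq!(layer.config().layer_idx, 0);
    Ok(())
}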
@@ -270,7 +286,7 @@ impl GRUConfig {
 /// A Gated Recurrent Unit (GRU) layer.
 ///
 /// <https://en.wikipedia.org/wiki/Gated_recurrent_unit>
-#[allow(clippy::upper_case_acronyms, unused)]
+#[allow(clippy::upper_case_acronyms)]
 #[derive(Clone, Debug)]
 pub struct GRU {
     w_ih: Tensor,
@@ -283,41 +299,56 @@ pub struct GRU {
     dtype: DType,
 }
 
-/// Creates a GRU layer.
+impl GRU {
+    /// Creates a GRU layer.
+    pub fn new(
+        in_dim: usize,
+        hidden_dim: usize,
+        config: GRUConfig,
+        vb: crate::VarBuilder,
+    ) -> Result<Self> {
+        let w_ih = vb.get_with_hints(
+            (3 * hidden_dim, in_dim),
+            "weight_ih_l0", // Only a single layer is supported.
+            config.w_ih_init,
+        )?;
+        let w_hh = vb.get_with_hints(
+            (3 * hidden_dim, hidden_dim),
+            "weight_hh_l0", // Only a single layer is supported.
+            config.w_hh_init,
+        )?;
+        let b_ih = match config.b_ih_init {
+            Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_ih_l0", init)?),
+            None => None,
+        };
+        let b_hh = match config.b_hh_init {
+            Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?),
+            None => None,
+        };
+        Ok(Self {
+            w_ih,
+            w_hh,
+            b_ih,
+            b_hh,
+            hidden_dim,
+            config,
+            device: vb.device().clone(),
+            dtype: vb.dtype(),
+        })
+    }
+
+    pub fn config(&self) -> &GRUConfig {
+        &self.config
+    }
+}
+
 pub fn gru(
     in_dim: usize,
     hidden_dim: usize,
     config: GRUConfig,
     vb: crate::VarBuilder,
 ) -> Result<GRU> {
-    let w_ih = vb.get_with_hints(
-        (3 * hidden_dim, in_dim),
-        "weight_ih_l0", // Only a single layer is supported.
-        config.w_ih_init,
-    )?;
-    let w_hh = vb.get_with_hints(
-        (3 * hidden_dim, hidden_dim),
-        "weight_hh_l0", // Only a single layer is supported.
-        config.w_hh_init,
-    )?;
-    let b_ih = match config.b_ih_init {
-        Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_ih_l0", init)?),
-        None => None,
-    };
-    let b_hh = match config.b_hh_init {
-        Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?),
-        None => None,
-    };
-    Ok(GRU {
-        w_ih,
-        w_hh,
-        b_ih,
-        b_hh,
-        hidden_dim,
-        config,
-        device: vb.device().clone(),
-        dtype: vb.dtype(),
-    })
+    GRU::new(in_dim, hidden_dim, config, vb)
 }
 
 impl RNN for GRU {
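
A similar sketch for the GRU side, under the same assumptions (the dimensions 8 and 24 are arbitrary): building through `GRU::new` and through the `gru(...)` wrapper now yields the same layer.

use candle::{DType, Device, Result};
use candle_nn::{GRUConfig, VarBuilder, VarMap, GRU};

fn build_gru(vb: VarBuilder) -> Result<GRU> {
    // `gru(8, 24, GRUConfig::default(), vb)` would build the same layer.
    let layer = GRU::new(8, 24, GRUConfig::default(), vb)?;
    // The config is accessible from the model after this change.
    let _cfg: &GRUConfig = layer.config();
    Ok(layer)
}

fn main() -> Result<()> {
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
    let _layer = build_gru(vb)?;
    Ok(())
}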
