Skip to content

Commit

Permalink
Add/lstm direction (#2455)
Browse files Browse the repository at this point in the history
* add: direction for lstm layer

* lint: remove unused Error import

* refactor: remove unnecessary int assignment to Direction enum

* refactor: use &'static str type instead of String for direction_str

* Run `cargo fmt`.

---------

Co-authored-by: Laurent <[email protected]>
  • Loading branch information
singjc and LaurentMazare authored Sep 30, 2024
1 parent 7246504 commit aa35bf2
Showing 1 changed file with 25 additions and 8 deletions.
33 changes: 25 additions & 8 deletions candle-nn/src/rnn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ impl LSTMState {
}
}

/// Direction of an LSTM layer.
///
/// The `lstm` constructor uses this value to choose the variable names it loads:
/// `Forward` uses the plain names (e.g. `weight_ih_l0`), while `Backward` appends a
/// `_reverse` suffix (e.g. `weight_ih_l0_reverse`).
///
/// `PartialEq`/`Eq` are derived so that callers can compare directions with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Direction {
    /// Parameters are looked up without any name suffix.
    Forward,
    /// Parameters are looked up with the `_reverse` name suffix.
    Backward,
}

#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy)]
pub struct LSTMConfig {
Expand All @@ -78,6 +84,7 @@ pub struct LSTMConfig {
pub b_ih_init: Option<super::Init>,
pub b_hh_init: Option<super::Init>,
pub layer_idx: usize,
pub direction: Direction,
}

impl Default for LSTMConfig {
Expand All @@ -88,6 +95,7 @@ impl Default for LSTMConfig {
b_ih_init: Some(super::Init::Const(0.)),
b_hh_init: Some(super::Init::Const(0.)),
layer_idx: 0,
direction: Direction::Forward,
}
}
}
Expand All @@ -100,6 +108,7 @@ impl LSTMConfig {
b_ih_init: None,
b_hh_init: None,
layer_idx: 0,
direction: Direction::Forward,
}
}
}
Expand Down Expand Up @@ -128,26 +137,34 @@ pub fn lstm(
vb: crate::VarBuilder,
) -> Result<LSTM> {
let layer_idx = config.layer_idx;
let direction_str = match config.direction {
Direction::Forward => "",
Direction::Backward => "_reverse",
};
let w_ih = vb.get_with_hints(
(4 * hidden_dim, in_dim),
&format!("weight_ih_l{layer_idx}"), // Only a single layer is supported.
&format!("weight_ih_l{layer_idx}{direction_str}"), // Only a single layer is supported.
config.w_ih_init,
)?;
let w_hh = vb.get_with_hints(
(4 * hidden_dim, hidden_dim),
&format!("weight_hh_l{layer_idx}"), // Only a single layer is supported.
&format!("weight_hh_l{layer_idx}{direction_str}"), // Only a single layer is supported.
config.w_hh_init,
)?;
let b_ih = match config.b_ih_init {
Some(init) => {
Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_ih_l{layer_idx}"), init)?)
}
Some(init) => Some(vb.get_with_hints(
4 * hidden_dim,
&format!("bias_ih_l{layer_idx}{direction_str}"),
init,
)?),
None => None,
};
let b_hh = match config.b_hh_init {
Some(init) => {
Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_hh_l{layer_idx}"), init)?)
}
Some(init) => Some(vb.get_with_hints(
4 * hidden_dim,
&format!("bias_hh_l{layer_idx}{direction_str}"),
init,
)?),
None => None,
};
Ok(LSTM {
Expand Down

0 comments on commit aa35bf2

Please sign in to comment.