Example of multi-layer and bidirectional RNNs
kigichang committed Oct 8, 2024
1 parent 2a3bec7 commit 0e3c841
Showing 2 changed files with 241 additions and 0 deletions.
15 changes: 15 additions & 0 deletions candle-examples/examples/rnn/README.md
# candle-rnn: Recurrent Neural Network

This example demonstrates how to use the `candle_nn::rnn` module to run LSTM and GRU networks, including multi-layer and bidirectional variants.
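
The core pattern is: build a layer from a `VarBuilder`, run it over the input with `seq`, then turn the per-step states back into a tensor with `states_to_tensor`. A minimal single-layer sketch of that pattern (standalone, mirroring the defaults used in `main.rs` below; not part of the example's CLI):

```rust
use anyhow::Result;
use candle::{DType, Device, Tensor};
use candle_nn::{LSTMConfig, RNN};

fn single_layer() -> Result<Tensor> {
    let device = Device::Cpu;
    let map = candle_nn::VarMap::new();
    let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device);

    // One forward LSTM layer: input_dim = 10, hidden_dim = 20.
    let lstm = candle_nn::lstm(10, 20, LSTMConfig::default(), vb)?;

    // Random input of shape (batch, seq_len, input_dim).
    let input = Tensor::randn(0.0_f32, 1.0, (5, 3, 10), &device)?;

    let states = lstm.seq(&input)?; // one hidden state per time step
    let output = lstm.states_to_tensor(&states)?; // (batch, seq_len, hidden_dim)
    Ok(output)
}
```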

## Running the example

```bash
$ cargo run --example rnn --release
```

Choose LSTM or GRU via the `--model` argument, set the number of layers via `--layers`, and enable bidirectional processing via `--bidirection`.

```bash
$ cargo run --example rnn --release -- --model lstm --layers 3 --bidirection
```
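
The example feeds random input of shape `(batch_size, seq_len, input_dim)` through the stack and asserts that the output has shape `(batch_size, seq_len, hidden_dim)`, with the last dimension doubled when `--bidirection` is set.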
226 changes: 226 additions & 0 deletions candle-examples/examples/rnn/main.rs
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Result;
use candle::{DType, Tensor};
use candle_nn::{rnn, LSTMConfig, RNN};
use clap::Parser;

#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
enum WhichModel {
#[value(name = "lstm")]
LSTM,
#[value(name = "gru")]
GRU,
}

#[derive(Debug, Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
#[arg(long)]
cpu: bool,

#[arg(long, default_value_t = 10)]
input_dim: usize,

#[arg(long, default_value_t = 20)]
hidden_dim: usize,

#[arg(long, default_value_t = 1)]
layers: usize,

#[arg(long)]
bidirection: bool,

#[arg(long, default_value_t = 5)]
batch_size: usize,

#[arg(long, default_value_t = 3)]
seq_len: usize,

#[arg(long, default_value = "lstm")]
model: WhichModel,
}

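/// Per-layer config: `layer_idx` keeps each layer's weights under a distinct
/// name in the `VarMap`, so stacked layers built from the same `VarBuilder`
/// get their own parameters.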
fn lstm_config(layer_idx: usize, direction: rnn::Direction) -> LSTMConfig {
LSTMConfig {
layer_idx,
direction,
..Default::default()
}
}

fn gru_config(layer_idx: usize, direction: rnn::Direction) -> rnn::GRUConfig {
rnn::GRUConfig {
layer_idx,
direction,
..Default::default()
}
}

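/// Runs a stack of unidirectional LSTM layers: the output sequence of each
/// layer is the input to the next.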
fn run_lstm(args: Args) -> Result<Tensor> {
let device = candle_examples::device(args.cpu)?;
let map = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device);

let mut layers = Vec::with_capacity(args.layers);

for layer_idx in 0..args.layers {
let input_dim = if layer_idx == 0 {
args.input_dim
} else {
args.hidden_dim
};
let config = lstm_config(layer_idx, rnn::Direction::Forward);
let lstm = candle_nn::lstm(input_dim, args.hidden_dim, config, vb.clone())?;
layers.push(lstm);
}

let mut input = Tensor::randn(
0.0_f32,
1.0,
(args.batch_size, args.seq_len, args.input_dim),
&device,
)?;

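// Feed the output sequence of each layer into the next layer.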
for layer in &layers {
let states = layer.seq(&input)?;
input = layer.states_to_tensor(&states)?;
}

Ok(input)
}

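/// Same stacking as `run_lstm`, but with GRU cells.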
fn run_gru(args: Args) -> Result<Tensor> {
let device = candle_examples::device(args.cpu)?;
let map = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device);

let mut layers = Vec::with_capacity(args.layers);

for layer_idx in 0..args.layers {
let input_dim = if layer_idx == 0 {
args.input_dim
} else {
args.hidden_dim
};
let config = gru_config(layer_idx, rnn::Direction::Forward);
let gru = candle_nn::gru(input_dim, args.hidden_dim, config, vb.clone())?;
layers.push(gru);
}

let mut input = Tensor::randn(
0.0_f32,
1.0,
(args.batch_size, args.seq_len, args.input_dim),
&device,
)?;

for layer in &layers {
let states = layer.seq(&input)?;
input = layer.states_to_tensor(&states)?;
}

Ok(input)
}

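/// Runs a stack of bidirectional LSTM layers: each layer runs a forward and a
/// backward LSTM over the same input and concatenates their outputs, so layers
/// after the first see `2 * hidden_dim` input features.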
fn run_bidirectional_lstm(args: Args) -> Result<Tensor> {
let device = candle_examples::device(args.cpu)?;
let map = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device);

let mut layers = Vec::with_capacity(args.layers);

for layer_idx in 0..args.layers {
let input_dim = if layer_idx == 0 {
args.input_dim
} else {
args.hidden_dim * 2
};

let forward_config = lstm_config(layer_idx, rnn::Direction::Forward);
let forward = candle_nn::lstm(input_dim, args.hidden_dim, forward_config, vb.clone())?;

let backward_config = lstm_config(layer_idx, rnn::Direction::Backward);
let backward = candle_nn::lstm(input_dim, args.hidden_dim, backward_config, vb.clone())?;

layers.push((forward, backward));
}

let mut input = Tensor::randn(
0.0_f32,
1.0,
(args.batch_size, args.seq_len, args.input_dim),
&device,
)?;

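// Run both directions over the same input and concatenate their outputs
// along the feature dimension.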
for (forward, backward) in &layers {
let forward_states = forward.seq(&input)?;
let backward_states = backward.seq(&input)?;
input = forward.combine_states_to_tensor(&forward_states, &backward_states)?;
}
Ok(input)
}

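/// Same as `run_bidirectional_lstm`, but with GRU cells.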
fn run_bidirectional_gru(args: Args) -> Result<Tensor> {
let device = candle_examples::device(args.cpu)?;
let map = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device);

let mut layers = Vec::with_capacity(args.layers);
for layer_idx in 0..args.layers {
let input_dim = if layer_idx == 0 {
args.input_dim
} else {
args.hidden_dim * 2
};

let forward_config = gru_config(layer_idx, rnn::Direction::Forward);
let forward = candle_nn::gru(input_dim, args.hidden_dim, forward_config, vb.clone())?;

let backward_config = gru_config(layer_idx, rnn::Direction::Backward);
let backward = candle_nn::gru(input_dim, args.hidden_dim, backward_config, vb.clone())?;

layers.push((forward, backward));
}

let mut input = Tensor::randn(
0.0_f32,
1.0,
(args.batch_size, args.seq_len, args.input_dim),
&device,
)?;

for (forward, backward) in &layers {
let forward_states = forward.seq(&input)?;
let backward_states = backward.seq(&input)?;
input = forward.combine_states_to_tensor(&forward_states, &backward_states)?;
}

Ok(input)
}

fn main() -> Result<()> {
let args = Args::parse();
let num_directions = if args.bidirection { 2 } else { 1 };
let batch_size = args.batch_size;
let seq_len = args.seq_len;
let hidden_dim = args.hidden_dim;

println!(
"Running {:?} (bidirectional: {}, layers: {})",
args.model, args.bidirection, args.layers
);

let output = match (args.model, args.bidirection) {
(WhichModel::LSTM, false) => run_lstm(args),
(WhichModel::GRU, false) => run_gru(args),
(WhichModel::LSTM, true) => run_bidirectional_lstm(args),
(WhichModel::GRU, true) => run_bidirectional_gru(args),
}?;

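// A bidirectional run concatenates both directions, doubling the feature dimension.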
assert_eq!(output.dims3()?, (batch_size, seq_len, hidden_dim * num_directions));

Ok(())
}
