From 0e3c8410c913d84c76afda67cdbd8318b987dd8d Mon Sep 17 00:00:00 2001
From: kigi
Date: Tue, 8 Oct 2024 21:22:04 +0800
Subject: [PATCH] Add an example of multi-layer and bidirectional RNNs

---
 candle-examples/examples/rnn/README.md |  15 ++
 candle-examples/examples/rnn/main.rs   | 226 +++++++++++++++++++++++++
 2 files changed, 241 insertions(+)
 create mode 100644 candle-examples/examples/rnn/README.md
 create mode 100644 candle-examples/examples/rnn/main.rs

diff --git a/candle-examples/examples/rnn/README.md b/candle-examples/examples/rnn/README.md
new file mode 100644
index 000000000..3f508bf00
--- /dev/null
+++ b/candle-examples/examples/rnn/README.md
@@ -0,0 +1,15 @@
+# candle-rnn: Recurrent Neural Network
+
+This example demonstrates how to use the `candle_nn::rnn` module to run LSTM and GRU networks, including multi-layer and bidirectional variants.
+
+## Running the example
+
+```bash
+$ cargo run --example rnn --release
+```
+
+Choose LSTM or GRU with the `--model` argument, set the number of layers with `--layers`, and enable bidirectional processing with the `--bidirection` flag.
+
+```bash
+$ cargo run --example rnn --release -- --model lstm --layers 3 --bidirection
+```
diff --git a/candle-examples/examples/rnn/main.rs b/candle-examples/examples/rnn/main.rs
new file mode 100644
index 000000000..7998ce893
--- /dev/null
+++ b/candle-examples/examples/rnn/main.rs
@@ -0,0 +1,226 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::Result;
+use candle::{DType, Tensor};
+use candle_nn::{rnn, LSTMConfig, RNN};
+use clap::Parser;
+
+#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
+enum WhichModel {
+    #[value(name = "lstm")]
+    LSTM,
+    #[value(name = "gru")]
+    GRU,
+}
+
+#[derive(Debug, Parser)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    #[arg(long)]
+    cpu: bool,
+
+    #[arg(long, default_value_t = 10)]
+    input_dim: usize,
+
+    #[arg(long, default_value_t = 20)]
+    hidden_dim: usize,
+
+    #[arg(long, default_value_t = 1)]
+    layers: usize,
+
+    #[arg(long)]
+    bidirection: bool,
+
+    #[arg(long, default_value_t = 5)]
+    batch_size: usize,
+
+    #[arg(long, default_value_t = 3)]
+    seq_len: usize,
+
+    #[arg(long, default_value = "lstm")]
+    model: WhichModel,
+}
+
+fn lstm_config(layer_idx: usize, direction: rnn::Direction) -> LSTMConfig {
+    let mut config = LSTMConfig::default();
+    config.layer_idx = layer_idx;
+    config.direction = direction;
+    config
+}
+
+fn gru_config(layer_idx: usize, direction: rnn::Direction) -> rnn::GRUConfig {
+    let mut config = rnn::GRUConfig::default();
+    config.layer_idx = layer_idx;
+    config.direction = direction;
+    config
+}
+
+/// Runs a stack of unidirectional LSTM layers over random input and returns
+/// the output of the last layer, of shape (batch, seq, hidden).
+fn run_lstm(args: Args) -> Result<Tensor> {
+    let device = candle_examples::device(args.cpu)?;
+    let map = candle_nn::VarMap::new();
+    let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device);
+
+    let mut layers = Vec::with_capacity(args.layers);
+
+    for layer_idx in 0..args.layers {
+        // Layers after the first consume the previous layer's hidden states.
+        let input_dim = if layer_idx == 0 {
+            args.input_dim
+        } else {
+            args.hidden_dim
+        };
+        let config = lstm_config(layer_idx, rnn::Direction::Forward);
+        let lstm = candle_nn::lstm(input_dim, args.hidden_dim, config, vb.clone())?;
+        layers.push(lstm);
+    }
+
+    let mut input = Tensor::randn(
+        0.0_f32,
+        1.0,
+        (args.batch_size, args.seq_len, args.input_dim),
+        &device,
+    )?;
+
+    for layer in &layers {
+        let states = layer.seq(&input)?;
+        input = layer.states_to_tensor(&states)?;
+    }
+
+    Ok(input)
+}
+
+/// Same as `run_lstm` but with GRU layers.
+fn run_gru(args: Args) -> Result<Tensor> {
+    let device = candle_examples::device(args.cpu)?;
+    let map = candle_nn::VarMap::new();
+    let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device);
+
+    let mut layers = Vec::with_capacity(args.layers);
+
+    for layer_idx in 0..args.layers {
+        let input_dim = if layer_idx == 0 {
+            args.input_dim
+        } else {
+            args.hidden_dim
+        };
+        let config = gru_config(layer_idx, rnn::Direction::Forward);
+        let gru = candle_nn::gru(input_dim, args.hidden_dim, config, vb.clone())?;
+        layers.push(gru);
+    }
+
+    let mut input = Tensor::randn(
+        0.0_f32,
+        1.0,
+        (args.batch_size, args.seq_len, args.input_dim),
+        &device,
+    )?;
+
+    for layer in &layers {
+        let states = layer.seq(&input)?;
+        input = layer.states_to_tensor(&states)?;
+    }
+
+    Ok(input)
+}
+
+/// Runs a stack of bidirectional LSTM layers: each layer processes the
+/// sequence in both directions and combines the results, so layers after
+/// the first consume inputs of size 2 * hidden_dim.
+fn run_bidirectional_lstm(args: Args) -> Result<Tensor> {
+    let device = candle_examples::device(args.cpu)?;
+    let map = candle_nn::VarMap::new();
+    let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device);
+
+    let mut layers = Vec::with_capacity(args.layers);
+
+    for layer_idx in 0..args.layers {
+        let input_dim = if layer_idx == 0 {
+            args.input_dim
+        } else {
+            args.hidden_dim * 2
+        };
+
+        let forward_config = lstm_config(layer_idx, rnn::Direction::Forward);
+        let forward = candle_nn::lstm(input_dim, args.hidden_dim, forward_config, vb.clone())?;
+
+        let backward_config = lstm_config(layer_idx, rnn::Direction::Backward);
+        let backward = candle_nn::lstm(input_dim, args.hidden_dim, backward_config, vb.clone())?;
+
+        layers.push((forward, backward));
+    }
+
+    let mut input = Tensor::randn(
+        0.0_f32,
+        1.0,
+        (args.batch_size, args.seq_len, args.input_dim),
+        &device,
+    )?;
+
+    for (forward, backward) in &layers {
+        let forward_states = forward.seq(&input)?;
+        let backward_states = backward.seq(&input)?;
+        input = forward.combine_states_to_tensor(&forward_states, &backward_states)?;
+    }
+
+    Ok(input)
+}
+
+/// Same as `run_bidirectional_lstm` but with GRU layers.
+fn run_bidirectional_gru(args: Args) -> Result<Tensor> {
+    let device = candle_examples::device(args.cpu)?;
+    let map = candle_nn::VarMap::new();
+    let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device);
+
+    let mut layers = Vec::with_capacity(args.layers);
+
+    for layer_idx in 0..args.layers {
+        let input_dim = if layer_idx == 0 {
+            args.input_dim
+        } else {
+            args.hidden_dim * 2
+        };
+
+        let forward_config = gru_config(layer_idx, rnn::Direction::Forward);
+        let forward = candle_nn::gru(input_dim, args.hidden_dim, forward_config, vb.clone())?;
+
+        let backward_config = gru_config(layer_idx, rnn::Direction::Backward);
+        let backward = candle_nn::gru(input_dim, args.hidden_dim, backward_config, vb.clone())?;
+
+        layers.push((forward, backward));
+    }
+
+    let mut input = Tensor::randn(
+        0.0_f32,
+        1.0,
+        (args.batch_size, args.seq_len, args.input_dim),
+        &device,
+    )?;
+
+    for (forward, backward) in &layers {
+        let forward_states = forward.seq(&input)?;
+        let backward_states = backward.seq(&input)?;
+        input = forward.combine_states_to_tensor(&forward_states, &backward_states)?;
+    }
+
+    Ok(input)
+}
+
+fn main() -> Result<()> {
+    let args = Args::parse();
+    let num_directions = if args.bidirection { 2 } else { 1 };
+    let batch_size = args.batch_size;
+    let seq_len = args.seq_len;
+    let hidden_dim = args.hidden_dim;
+
+    println!(
+        "Running {:?} (bidirectional: {}, layers: {})",
+        args.model, args.bidirection, args.layers
+    );
+
+    let output = match (args.model, args.bidirection) {
+        (WhichModel::LSTM, false) => run_lstm(args),
+        (WhichModel::GRU, false) => run_gru(args),
+        (WhichModel::LSTM, true) => run_bidirectional_lstm(args),
+        (WhichModel::GRU, true) => run_bidirectional_gru(args),
+    }?;
+
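+    // With --bidirection, the feature dimension of the output doubles: the
+    // forward and backward hidden states are combined at each time step.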
+    assert_eq!(output.dims3()?, (batch_size, seq_len, hidden_dim * num_directions));
+
+    Ok(())
+}
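A note on the final shape check (not part of the patch): with `--bidirection`, the expected output width is `hidden_dim * 2`, which assumes `combine_states_to_tensor` joins the forward and backward outputs along the feature axis. A minimal sketch of that bookkeeping with plain tensor ops, using the example's default dimensions (batch 5, seq 3, hidden 20):

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Stand-ins for the per-direction outputs of one bidirectional layer:
    // each has shape (batch, seq, hidden) = (5, 3, 20).
    let forward = Tensor::zeros((5, 3, 20), DType::F32, &dev)?;
    let backward = Tensor::zeros((5, 3, 20), DType::F32, &dev)?;
    // Concatenating along the last axis yields (batch, seq, 2 * hidden),
    // which is why main() checks for hidden_dim * 2 when --bidirection is set.
    let combined = Tensor::cat(&[&forward, &backward], 2)?;
    assert_eq!(combined.dims3()?, (5, 3, 40));
    Ok(())
}
```

The same reasoning explains why every bidirectional layer after the first is built with `input_dim = hidden_dim * 2`: it consumes the concatenated output of the previous layer.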