From 54d881943a62f0ecdc785e74531cd2e5c42586a4 Mon Sep 17 00:00:00 2001
From: kigi
Date: Fri, 11 Oct 2024 16:41:32 +0800
Subject: [PATCH] add test for example to confirm that the results are similar to pytorch

---
 candle-examples/examples/rnn/README.md | 152 +++++++++++
 candle-examples/examples/rnn/main.rs   | 341 +++++++++++++++++++------
 2 files changed, 410 insertions(+), 83 deletions(-)

diff --git a/candle-examples/examples/rnn/README.md b/candle-examples/examples/rnn/README.md
index 3f508bf00..567c9a96b 100644
--- a/candle-examples/examples/rnn/README.md
+++ b/candle-examples/examples/rnn/README.md
@@ -13,3 +13,155 @@ Choose LSTM or GRU via the `--model` argument, number of layers via `--layer`, a
 ```bash
 $ cargo run --example rnn --release -- --model lstm --layers 3 --bidirection
 ```
+
+## Running the example test
+
+Add the `--test` argument to run the tests for this example.
+
+```bash
+$ cargo run --example rnn --release -- --test
+```
+
+The test models are built with PyTorch [LSTM](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html) and [GRU](https://pytorch.org/docs/stable/generated/torch.nn.GRU.html). Each file also stores the input and output tensors, and all of them can be downloaded from [here](https://huggingface.co/kigichang/test_rnn).
+
+They were generated with the following scripts:
+
+- lstm_test.pt: A simple LSTM model with 1 layer.
+
+  ```python
+  import torch
+  import torch.nn as nn
+
+  rnn = nn.LSTM(10, 20, num_layers=1, batch_first=True)
+  input = torch.randn(5, 3, 10)
+  output, (hn, cn) = rnn(input)
+
+  state_dict = rnn.state_dict()
+  state_dict['input'] = input
+  state_dict['output'] = output.contiguous()
+  state_dict['hn'] = hn
+  state_dict['cn'] = cn
+  torch.save(state_dict, "lstm_test.pt")
+  ```
+
+- gru_test.pt: A simple GRU model with 1 layer.
+
+  ```python
+  import torch
+  import torch.nn as nn
+
+  rnn = nn.GRU(10, 20, num_layers=1, batch_first=True)
+  input = torch.randn(5, 3, 10)
+  output, hn = rnn(input)
+
+  state_dict = rnn.state_dict()
+  state_dict['input'] = input
+  state_dict['output'] = output.contiguous()
+  state_dict['hn'] = hn
+  torch.save(state_dict, "gru_test.pt")
+  ```
+
+- bi_lstm_test.pt: A bidirectional LSTM model with 1 layer.
+
+  ```python
+  import torch
+  import torch.nn as nn
+
+  rnn = nn.LSTM(10, 20, num_layers=1, bidirectional=True, batch_first=True)
+  input = torch.randn(5, 3, 10)
+  output, (hn, cn) = rnn(input)
+
+  state_dict = rnn.state_dict()
+  state_dict['input'] = input
+  state_dict['output'] = output.contiguous()
+  state_dict['hn'] = hn
+  state_dict['cn'] = cn
+  torch.save(state_dict, "bi_lstm_test.pt")
+  ```
+
+- bi_gru_test.pt: A bidirectional GRU model with 1 layer.
+
+  ```python
+  import torch
+  import torch.nn as nn
+
+  rnn = nn.GRU(10, 20, num_layers=1, bidirectional=True, batch_first=True)
+  input = torch.randn(5, 3, 10)
+  output, hn = rnn(input)
+
+  state_dict = rnn.state_dict()
+  state_dict['input'] = input
+  state_dict['output'] = output.contiguous()
+  state_dict['hn'] = hn
+  torch.save(state_dict, "bi_gru_test.pt")
+  ```
+
+- lstm_nlayer_test.pt: An LSTM model with 3 layers.
+
+  ```python
+  import torch
+  import torch.nn as nn
+
+  rnn = nn.LSTM(10, 20, num_layers=3, batch_first=True)
+  input = torch.randn(5, 3, 10)
+  output, (hn, cn) = rnn(input)
+
+  state_dict = rnn.state_dict()
+  state_dict['input'] = input
+  state_dict['output'] = output.contiguous()
+  state_dict['hn'] = hn
+  state_dict['cn'] = cn
+  torch.save(state_dict, "lstm_nlayer_test.pt")
+  ```
+
+- bi_lstm_nlayer_test.pt: A bidirectional LSTM model with 3 layers.
+
+  ```python
+  import torch
+  import torch.nn as nn
+
+  rnn = nn.LSTM(10, 20, num_layers=3, bidirectional=True, batch_first=True)
+  input = torch.randn(5, 3, 10)
+  output, (hn, cn) = rnn(input)
+
+  state_dict = rnn.state_dict()
+  state_dict['input'] = input
+  state_dict['output'] = output.contiguous()
+  state_dict['hn'] = hn
+  state_dict['cn'] = cn
+  torch.save(state_dict, "bi_lstm_nlayer_test.pt")
+  ```
+
+- gru_nlayer_test.pt: A GRU model with 3 layers.
+
+  ```python
+  import torch
+  import torch.nn as nn
+
+  rnn = nn.GRU(10, 20, num_layers=3, batch_first=True)
+  input = torch.randn(5, 3, 10)
+  output, hn = rnn(input)
+
+  state_dict = rnn.state_dict()
+  state_dict['input'] = input
+  state_dict['output'] = output.contiguous()
+  state_dict['hn'] = hn
+  torch.save(state_dict, "gru_nlayer_test.pt")
+  ```
+
+- bi_gru_nlayer_test.pt: A bidirectional GRU model with 3 layers.
+
+  ```python
+  import torch
+  import torch.nn as nn
+
+  rnn = nn.GRU(10, 20, num_layers=3, bidirectional=True, batch_first=True)
+  input = torch.randn(5, 3, 10)
+  output, hn = rnn(input)
+
+  state_dict = rnn.state_dict()
+  state_dict['input'] = input
+  state_dict['output'] = output.contiguous()
+  state_dict['hn'] = hn
+  torch.save(state_dict, "bi_gru_nlayer_test.pt")
+  ```
diff --git a/candle-examples/examples/rnn/main.rs b/candle-examples/examples/rnn/main.rs
index 7998ce893..83364a8ad 100644
--- a/candle-examples/examples/rnn/main.rs
+++ b/candle-examples/examples/rnn/main.rs
@@ -5,9 +5,12 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 
 use anyhow::Result;
-use candle::{DType, Tensor};
-use candle_nn::{rnn, LSTMConfig, RNN};
+use candle::{DType, Device, Tensor, D};
+use candle_nn::{rnn, LSTMConfig, VarBuilder, RNN};
 use clap::Parser;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+
+const ACCURACY: f32 = 1e-6;
 
 #[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
 enum WhichModel {
@@ -17,7 +20,7 @@ enum WhichModel {
     GRU,
 }
 
-#[derive(Debug, Parser)]
+#[derive(Clone, Copy, Debug, Parser)]
 #[command(author, version, about, long_about = None)]
 struct Args {
     #[arg(long)]
@@ -43,6 +46,102 @@ struct Args {
 
     #[arg(long, default_value = "lstm")]
     model: WhichModel,
+
+    #[arg(long)]
+    test: bool,
+}
+
+impl Args {
+    pub fn load_model(&self) -> Result<(Config, VarBuilder<'static>, Tensor)> {
+        let device = self.device()?;
+        if self.test {
+            // Download the reference test model from the Hugging Face hub.
+            let model = match self.model {
+                WhichModel::LSTM => "lstm",
+                WhichModel::GRU => "gru",
+            };
+
+            let bidirection = if self.bidirection { "bi_" } else { "" };
+            let layer = if self.layers > 1 { "_nlayer" } else { "" };
+            let model = format!("{}{}{}_test", bidirection, model, layer);
+            let (config, vb) = load_model(&model, &device)?;
+            let input = vb.get(
+                (config.batch_size, config.sequence_length, config.input),
+                "input",
+            )?;
+            Ok((config, vb, input))
+        } else {
+            let map = candle_nn::VarMap::new();
+            let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device);
+            let input = Tensor::randn(
+                0.0_f32,
+                1.0,
+                (self.batch_size, self.seq_len, self.input_dim),
+                &device,
+            )?;
+            Ok((self.into(), vb, input))
+        }
+    }
+
+    pub fn device(&self) -> Result<Device> {
+        Ok(candle_examples::device(self.cpu)?)
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
+struct Config {
+    pub input: usize,
+    pub batch_size: usize,
+    pub sequence_length: usize,
+    pub hidden: usize,
+    pub layers: usize,
+    pub bidirection: bool,
+}
+
+impl From<&Args> for Config {
+    fn from(args: &Args) -> Self {
+        Config {
+            input: args.input_dim,
+            batch_size: args.batch_size,
+            sequence_length: args.seq_len,
+            hidden: args.hidden_dim,
+            layers: args.layers,
+            bidirection: args.bidirection,
+        }
+    }
+}
+
+/// Downloads the reference weights and the matching config for `model` from the Hugging Face hub.
+fn load_model(model: &str, device: &Device) -> Result<(Config, VarBuilder<'static>)> {
+    let api = Api::new()?;
+    let repo_id = "kigichang/test_rnn".to_string();
+    let repo = api.repo(Repo::with_revision(
+        repo_id,
+        RepoType::Model,
+        "main".to_string(),
+    ));
+
+    let filename = repo.get(&format!("{}.pt", model))?;
+    let config_file = repo.get(&format!("{}.json", model))?;
+
+    let config: Config = serde_json::from_slice(&std::fs::read(config_file)?)?;
+    let vb = VarBuilder::from_pth(filename, DType::F32, device)?;
+
+    Ok((config, vb))
+}
+
+/// Asserts that `a` and `b` have the same shape and that their maximum absolute difference is below `v`.
+fn assert_tensor(a: &Tensor, b: &Tensor, v: f32) -> Result<()> {
+    assert_eq!(a.dims(), b.dims());
+    let dim = a.dims().len();
+    let mut t = (a - b)?.abs()?;
+
+    for _i in 0..dim {
+        t = t.max(D::Minus1)?;
+    }
+
+    let t = t.to_scalar::<f32>()?;
+    println!("max diff = {}", t);
+    assert!(t < v);
+    Ok(())
 }
 
 fn lstm_config(layer_idx: usize, direction: rnn::Direction) -> LSTMConfig {
@@ -60,142 +159,139 @@ fn gru_config(layer_idx: usize, direction: rnn::Direction) -> rnn::GRUConfig {
 }
 
 fn run_lstm(args: Args) -> Result<Tensor> {
-    let device = candle_examples::device(args.cpu)?;
-    let map = candle_nn::VarMap::new();
-    let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device);
+    let (config, vb, mut input) = args.load_model()?;
 
-    let mut layers = Vec::with_capacity(args.layers);
+    let mut layers = Vec::with_capacity(config.layers);
 
-    for layer_idx in 0..args.layers {
+    for layer_idx in 0..config.layers {
         let input_dim = if layer_idx == 0 {
-            args.input_dim
+            config.input
         } else {
-            args.hidden_dim
+            config.hidden
         };
-        let config = lstm_config(layer_idx, rnn::Direction::Forward);
-        let lstm = candle_nn::lstm(input_dim, args.hidden_dim, config, vb.clone())?;
+        let lstm_config = lstm_config(layer_idx, rnn::Direction::Forward);
+        let lstm = candle_nn::lstm(input_dim, config.hidden, lstm_config, vb.clone())?;
         layers.push(lstm);
     }
 
-    let mut input = Tensor::randn(
-        0.0_f32,
-        1.0,
-        (args.batch_size, args.seq_len, args.input_dim),
-        &device,
-    )?;
-
     for layer in &layers {
         let states = layer.seq(&input)?;
         input = layer.states_to_tensor(&states)?;
     }
 
+    if args.test {
+        let answer = vb.get(
+            (config.batch_size, config.sequence_length, config.hidden),
+            "output",
+        )?;
+        assert_tensor(&input, &answer, ACCURACY)?;
+    }
+
     Ok(input)
 }
 
 fn run_gru(args: Args) -> Result<Tensor> {
-    let device = candle_examples::device(args.cpu)?;
-    let map = candle_nn::VarMap::new();
-    let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device);
+    let (config, vb, mut input) = args.load_model()?;
 
-    let mut layers = Vec::with_capacity(args.layers);
+    let mut layers = Vec::with_capacity(config.layers);
 
-    for layer_idx in 0..args.layers {
+    for layer_idx in 0..config.layers {
         let input_dim = if layer_idx == 0 {
-            args.input_dim
+            config.input
         } else {
-            args.hidden_dim
+            config.hidden
        };
-        let config = gru_config(layer_idx, rnn::Direction::Forward);
-        let gru = candle_nn::gru(input_dim, args.hidden_dim, config, vb.clone())?;
+        let gru_config = gru_config(layer_idx, rnn::Direction::Forward);
+        let gru = candle_nn::gru(input_dim, config.hidden, gru_config, vb.clone())?;
         layers.push(gru);
     }
 
-    let mut input = Tensor::randn(
-        0.0_f32,
-        1.0,
-        (args.batch_size, args.seq_len, args.input_dim),
-        &device,
-    )?;
-
     for layer in &layers {
         let states = layer.seq(&input)?;
         input = layer.states_to_tensor(&states)?;
     }
 
+    if args.test {
+        let answer = vb.get(
+            (config.batch_size, config.sequence_length, config.hidden),
+            "output",
+        )?;
+        assert_tensor(&input, &answer, ACCURACY)?;
+    }
+
     Ok(input)
 }
 
 fn run_bidirectional_lstm(args: Args) -> Result<Tensor> {
-    let device = candle_examples::device(args.cpu)?;
-    let map = candle_nn::VarMap::new();
-    let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device);
+    let (config, vb, mut input) = args.load_model()?;
 
-    let mut layers = Vec::with_capacity(args.layers);
+    let mut layers = Vec::with_capacity(config.layers);
 
-    for layer_idx in 0..args.layers {
+    for layer_idx in 0..config.layers {
         let input_dim = if layer_idx == 0 {
-            args.input_dim
+            config.input
         } else {
-            args.hidden_dim * 2
+            config.hidden * 2
         };
         let forward_config = lstm_config(layer_idx, rnn::Direction::Forward);
-        let forward = candle_nn::lstm(input_dim, args.hidden_dim, forward_config, vb.clone())?;
+        let forward = candle_nn::lstm(input_dim, config.hidden, forward_config, vb.clone())?;
 
         let backward_config = lstm_config(layer_idx, rnn::Direction::Backward);
-        let backward = candle_nn::lstm(input_dim, args.hidden_dim, backward_config, vb.clone())?;
+        let backward = candle_nn::lstm(input_dim, config.hidden, backward_config, vb.clone())?;
 
         layers.push((forward, backward));
     }
 
-    let mut input = Tensor::randn(
-        0.0_f32,
-        1.0,
-        (args.batch_size, args.seq_len, args.input_dim),
-        &device,
-    )?;
-
     for (forward, backward) in &layers {
         let forward_states = forward.seq(&input)?;
         let backward_states = backward.seq(&input)?;
-        input = forward.combine_states_to_tensor(&forward_states, &backward_states)?;
+        input = forward.bidirectional_states_to_tensor(&forward_states, &backward_states)?;
+    }
+
+    if args.test {
+        let answer = vb.get(
+            (config.batch_size, config.sequence_length, config.hidden * 2),
+            "output",
+        )?;
+        assert_tensor(&input, &answer, ACCURACY)?;
     }
 
     Ok(input)
 }
 
 fn run_bidirectional_gru(args: Args) -> Result<Tensor> {
-    let device = candle_examples::device(args.cpu)?;
-    let map = candle_nn::VarMap::new();
-    let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device);
+    let (config, vb, mut input) = args.load_model()?;
 
-    let mut layers = Vec::with_capacity(args.layers);
-    for layer_idx in 0..args.layers {
+    let mut layers = Vec::with_capacity(config.layers);
+    for layer_idx in 0..config.layers {
         let input_dim = if layer_idx == 0 {
-            args.input_dim
+            config.input
         } else {
-            args.hidden_dim * 2
+            config.hidden * 2
         };
         let forward_config = gru_config(layer_idx, rnn::Direction::Forward);
-        let forward = candle_nn::gru(input_dim, args.hidden_dim, forward_config, vb.clone())?;
+        let forward = candle_nn::gru(input_dim, config.hidden, forward_config, vb.clone())?;
 
         let backward_config = gru_config(layer_idx, rnn::Direction::Backward);
-        let backward = candle_nn::gru(input_dim, args.hidden_dim, backward_config, vb.clone())?;
+        let backward = candle_nn::gru(input_dim, config.hidden, backward_config, vb.clone())?;
 
         layers.push((forward, backward));
     }
 
-    let mut input = Tensor::randn(
-        0.0_f32,
-        1.0,
-        (args.batch_size, args.seq_len, args.input_dim),
-        &device,
-    )?;
-
     for (forward, backward) in &layers {
         let forward_states = forward.seq(&input)?;
         let backward_states = backward.seq(&input)?;
-        input = forward.combine_states_to_tensor(&forward_states, &backward_states)?;
+        input = forward.bidirectional_states_to_tensor(&forward_states, &backward_states)?;
+    }
+
+    if args.test {
+        let answer = vb.get(
+            (config.batch_size, config.sequence_length, config.hidden * 2),
+            "output",
+        )?;
+        assert_tensor(&input, &answer, ACCURACY)?;
     }
 
     Ok(input)
 }
@@ -203,24 +299,103 @@ fn run_bidirectional_gru(args: Args) -> Result<Tensor> {
 
 fn main() -> Result<()> {
     let args = Args::parse();
-    let runs = if args.bidirection { 2 } else { 1 };
-    let batch_size = args.batch_size;
-    let seq_len = args.seq_len;
-    let hidden_dim = args.hidden_dim;
 
     println!(
-        "Running {:?} bidirection: {} layers: {}",
-        args.model, args.bidirection, args.layers
+        "Running {:?} bidirection: {} layers: {} example-test: {}",
+        args.model, args.bidirection, args.layers, args.test
     );
 
-    let output = match (args.model, args.bidirection) {
-        (WhichModel::LSTM, false) => run_lstm(args),
-        (WhichModel::GRU, false) => run_gru(args),
-        (WhichModel::LSTM, true) => run_bidirectional_lstm(args),
-        (WhichModel::GRU, true) => run_bidirectional_gru(args),
-    }?;
-
-    assert_eq!(output.dims3()?, (batch_size, seq_len, hidden_dim * runs));
+    if args.test {
+        let test_args = Args {
+            model: WhichModel::LSTM,
+            bidirection: false,
+            layers: 1,
+            ..args
+        };
+        print!("Testing LSTM with 1 layer: ");
+        run_lstm(test_args)?;
+
+        let test_args = Args {
+            model: WhichModel::GRU,
+            bidirection: false,
+            layers: 1,
+            ..args
+        };
+        print!("Testing GRU with 1 layer: ");
+        run_gru(test_args)?;
+
+        let test_args = Args {
+            model: WhichModel::LSTM,
+            bidirection: true,
+            layers: 1,
+            ..args
+        };
+        print!("Testing bidirectional LSTM with 1 layer: ");
+        run_bidirectional_lstm(test_args)?;
+
+        let test_args = Args {
+            model: WhichModel::GRU,
+            bidirection: true,
+            layers: 1,
+            ..args
+        };
+        print!("Testing bidirectional GRU with 1 layer: ");
+        run_bidirectional_gru(test_args)?;
+
+        let test_args = Args {
+            model: WhichModel::LSTM,
+            bidirection: false,
+            layers: 3,
+            ..args
+        };
+        print!("Testing LSTM with 3 layers: ");
+        run_lstm(test_args)?;
+
+        let test_args = Args {
+            model: WhichModel::GRU,
+            bidirection: false,
+            layers: 3,
+            ..args
+        };
+        print!("Testing GRU with 3 layers: ");
+        run_gru(test_args)?;
+
+        let test_args = Args {
+            model: WhichModel::LSTM,
+            bidirection: true,
+            layers: 3,
+            ..args
+        };
+        print!("Testing bidirectional LSTM with 3 layers: ");
+        run_bidirectional_lstm(test_args)?;
+
+        let test_args = Args {
+            model: WhichModel::GRU,
+            bidirection: true,
+            layers: 3,
+            ..args
+        };
+        print!("Testing bidirectional GRU with 3 layers: ");
+        run_bidirectional_gru(test_args)?;
+    } else {
+        let num_directions = if args.bidirection { 2 } else { 1 };
+        let batch_size = args.batch_size;
+        let seq_len = args.seq_len;
+        let hidden_dim = args.hidden_dim;
+
+        let output = match (args.model, args.bidirection) {
+            (WhichModel::LSTM, false) => run_lstm(args),
+            (WhichModel::GRU, false) => run_gru(args),
+            (WhichModel::LSTM, true) => run_bidirectional_lstm(args),
+            (WhichModel::GRU, true) => run_bidirectional_gru(args),
+        }?;
+
+        assert_eq!(
+            output.dims3()?,
+            (batch_size, seq_len, hidden_dim * num_directions)
+        );
+        println!("result dims: {:?}", output.dims());
+    }
 
     Ok(())
 }
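
Note that `load_model` in main.rs also fetches a `{model}.json` file next to each `.pt` file and deserializes it into the `Config` struct (`input`, `batch_size`, `sequence_length`, `hidden`, `layers`, `bidirection`). The patch does not show how those JSON files are produced; the sketch below is a minimal, hypothetical helper, assuming the JSON keys simply mirror the `Config` field names used on the `kigichang/test_rnn` repo.

```python
# Hypothetical helper: writes the JSON sidecar that load_model() in main.rs expects.
# The exact layout of the files hosted at huggingface.co/kigichang/test_rnn is assumed here.
import json

def write_config(path, input_dim, hidden, layers, bidirection,
                 batch_size=5, sequence_length=3):
    config = {
        "input": input_dim,
        "batch_size": batch_size,
        "sequence_length": sequence_length,
        "hidden": hidden,
        "layers": layers,
        "bidirection": bidirection,
    }
    with open(path, "w") as f:
        json.dump(config, f, indent=2)

# Sidecar for lstm_test.pt above: nn.LSTM(10, 20, num_layers=1) applied to a (5, 3, 10) input.
write_config("lstm_test.json", input_dim=10, hidden=20, layers=1, bidirection=False)
```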
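
The example test compares the candle output against the stored PyTorch output using the maximum absolute element difference and the tolerance `ACCURACY = 1e-6` (`assert_tensor` in main.rs). The same check can be reproduced on the PyTorch side against a saved test model; the snippet below is an illustrative cross-check and not part of the patch.

```python
# Rebuild the reference LSTM from lstm_test.pt and re-verify the stored output,
# mirroring assert_tensor() in main.rs: max |a - b| must stay below 1e-6.
import torch
import torch.nn as nn

ACCURACY = 1e-6

def assert_tensor(a, b, tol=ACCURACY):
    assert a.shape == b.shape
    max_diff = (a - b).abs().max().item()
    print(f"max diff = {max_diff}")
    assert max_diff < tol

state = torch.load("lstm_test.pt")
rnn = nn.LSTM(10, 20, num_layers=1, batch_first=True)
# Drop the extra tensors stored alongside the weights before loading the state dict.
rnn.load_state_dict({k: v for k, v in state.items() if k not in ("input", "output", "hn", "cn")})
output, _ = rnn(state["input"])
assert_tensor(output, state["output"])
```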