diff --git a/Cargo.lock b/Cargo.lock index ab9eddb0bd..c9ce522699 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3855,6 +3855,17 @@ dependencies = [ "burn-import", ] +[[package]] +name = "modern-lstm" +version = "0.1.0" +dependencies = [ + "burn", + "polars", + "rand", + "rand_distr", + "serde", +] + [[package]] name = "monostate" version = "0.1.13" diff --git a/examples/modern-lstm/Cargo.toml b/examples/modern-lstm/Cargo.toml new file mode 100644 index 0000000000..86855e9ad4 --- /dev/null +++ b/examples/modern-lstm/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "modern-lstm" +version = "0.1.0" +edition = "2021" + +[features] +ndarray = ["burn/ndarray"] +ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"] +ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"] +ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"] +tch-cpu = ["burn/tch"] +tch-gpu = ["burn/tch"] +wgpu = ["burn/wgpu"] +cuda = ["burn/cuda"] + +[dependencies] +burn = { path = "../../crates/burn", features=["train"] } + +# Random number generator +rand = { workspace = true } +rand_distr = { workspace = true } + +# Serialization +serde = {workspace = true, features = ["std", "derive"]} + +# Organise the results in dataframe +polars = { workspace = true } diff --git a/examples/modern-lstm/README.md b/examples/modern-lstm/README.md new file mode 100644 index 0000000000..832851a1f0 --- /dev/null +++ b/examples/modern-lstm/README.md @@ -0,0 +1,46 @@ +# Advanced LSTM Implementation with Burn + +A more advanced implementation of Long Short-Term Memory (LSTM) networks in Burn with combined +weight matrices for the input and hidden states, based on the +[PyTorch implementation](https://github.com/shiv08/Advanced-LSTM-Implementation-with-PyTorch). + +`LstmNetwork` is the top-level module with bidirectional and regularization support. 
The LSTM +variants differ by `bidirectional` and `num_layers` settings: + +- LSTM: `num_layers = 1` and `bidirectional = false` +- Stacked LSTM: `num_layers > 1` and `bidirectional = false` +- Bidirectional LSTM: `num_layers = 1` and `bidirectional = true` +- Bidirectional Stacked LSTM: `num_layers > 1` and `bidirectional = true` + +This implementation is complementary to Burn's official LSTM; users can choose either one depending on +the project's specific needs. + +## Usage + +### Training + +```sh +# Cuda backend +cargo run --example lstm-train --release --features cuda + +# Wgpu backend +cargo run --example lstm-train --release --features wgpu + +# Tch GPU backend +export TORCH_CUDA_VERSION=cu121 # Set the cuda version +cargo run --example lstm-train --release --features tch-gpu + +# Tch CPU backend +cargo run --example lstm-train --release --features tch-cpu + +# NdArray backend (CPU) +cargo run --example lstm-train --release --features ndarray +cargo run --example lstm-train --release --features ndarray-blas-openblas +cargo run --example lstm-train --release --features ndarray-blas-netlib +``` + +### Inference + +```sh +cargo run --example lstm-infer --release --features cuda +``` diff --git a/examples/modern-lstm/examples/lstm-infer.rs b/examples/modern-lstm/examples/lstm-infer.rs new file mode 100644 index 0000000000..f601d08c79 --- /dev/null +++ b/examples/modern-lstm/examples/lstm-infer.rs @@ -0,0 +1,86 @@ +use burn::tensor::backend::Backend; + +pub fn launch(device: B::Device) { + modern_lstm::inference::infer::("/tmp/modern-lstm", device); +} + +#[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", +))] +mod ndarray { + use burn::backend::ndarray::{NdArray, NdArrayDevice}; + + use crate::launch; + + pub fn run() { + launch::(NdArrayDevice::Cpu); + } +} + +#[cfg(feature = "tch-gpu")] +mod tch_gpu { + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + + 
use crate::launch; + + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; + + launch::(device); + } +} + +#[cfg(feature = "tch-cpu")] +mod tch_cpu { + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + + use crate::launch; + + pub fn run() { + launch::(LibTorchDevice::Cpu); + } +} + +#[cfg(feature = "wgpu")] +mod wgpu { + use crate::launch; + use burn::backend::wgpu::Wgpu; + + pub fn run() { + launch::(Default::default()); + } +} + +#[cfg(feature = "cuda")] +mod cuda { + use crate::launch; + use burn::backend::Cuda; + + pub fn run() { + launch::(Default::default()); + } +} + +fn main() { + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + ndarray::run(); + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "tch-cpu")] + tch_cpu::run(); + #[cfg(feature = "wgpu")] + wgpu::run(); + #[cfg(feature = "cuda")] + cuda::run(); +} diff --git a/examples/modern-lstm/examples/lstm-train.rs b/examples/modern-lstm/examples/lstm-train.rs new file mode 100644 index 0000000000..454263d331 --- /dev/null +++ b/examples/modern-lstm/examples/lstm-train.rs @@ -0,0 +1,104 @@ +use burn::{ + grad_clipping::GradientClippingConfig, optim::AdamConfig, tensor::backend::AutodiffBackend, +}; +use modern_lstm::{model::LstmNetworkConfig, training::TrainingConfig}; + +pub fn launch(device: B::Device) { + let config = TrainingConfig::new( + LstmNetworkConfig::new(), + // Gradient clipping via optimizer config + AdamConfig::new().with_grad_clipping(Some(GradientClippingConfig::Norm(1.0))), + ); + + modern_lstm::training::train::("/tmp/modern-lstm", config, device); +} + +#[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", +))] +mod ndarray { + use burn::backend::{ + ndarray::{NdArray, 
NdArrayDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + launch::>(NdArrayDevice::Cpu); + } +} + +#[cfg(feature = "tch-gpu")] +mod tch_gpu { + use burn::backend::{ + libtorch::{LibTorch, LibTorchDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; + + launch::>(device); + } +} + +#[cfg(feature = "tch-cpu")] +mod tch_cpu { + use burn::backend::{ + libtorch::{LibTorch, LibTorchDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + launch::>(LibTorchDevice::Cpu); + } +} + +#[cfg(feature = "wgpu")] +mod wgpu { + use crate::launch; + use burn::backend::{wgpu::Wgpu, Autodiff}; + + pub fn run() { + launch::>(Default::default()); + } +} + +#[cfg(feature = "cuda")] +mod cuda { + use crate::launch; + use burn::backend::{cuda::CudaDevice, Autodiff, Cuda}; + + pub fn run() { + launch::>(CudaDevice::default()); + } +} + +fn main() { + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + ndarray::run(); + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "tch-cpu")] + tch_cpu::run(); + #[cfg(feature = "wgpu")] + wgpu::run(); + #[cfg(feature = "cuda")] + cuda::run(); +} diff --git a/examples/modern-lstm/src/dataset.rs b/examples/modern-lstm/src/dataset.rs new file mode 100644 index 0000000000..b2d04d525f --- /dev/null +++ b/examples/modern-lstm/src/dataset.rs @@ -0,0 +1,110 @@ +use burn::{ + data::{ + dataloader::batcher::Batcher, + dataset::{Dataset, InMemDataset}, + }, + prelude::*, +}; +use rand::Rng; +use rand_distr::{Distribution, Normal}; +use serde::{Deserialize, Serialize}; + +// Dataset parameters +pub const NUM_SEQUENCES: usize = 1000; +pub const SEQ_LENGTH: usize = 10; +pub const NOISE_LEVEL: f32 = 0.1; +pub const RANDOM_SEED: u64 = 5; + +// Generate a sequence where 
each number is the sum of previous two numbers plus noise +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct SequenceDatasetItem { + pub sequence: Vec, + pub target: f32, +} + +impl SequenceDatasetItem { + pub fn new(seq_length: usize, noise_level: f32) -> Self { + // Start with two random numbers between 0 and 1 + let mut seq = vec![rand::thread_rng().gen(), rand::thread_rng().gen()]; + + // Generate sequence + for _i in 0..seq_length { + // Next number is sum of previous two plus noise + let normal = Normal::new(0.0, noise_level).unwrap(); + let next_val = + seq[seq.len() - 2] + seq[seq.len() - 1] + normal.sample(&mut rand::thread_rng()); + seq.push(next_val); + } + + Self { + // Convert to sequence and target + sequence: seq[0..seq.len() - 1].to_vec(), // All but last + target: seq[seq.len() - 1], // Last value + } + } +} + +// Custom Dataset for Sequence Data +pub struct SequenceDataset { + dataset: InMemDataset, +} + +impl SequenceDataset { + pub fn new(num_sequences: usize, seq_length: usize, noise_level: f32) -> Self { + let mut items = vec![]; + for _i in 0..num_sequences { + items.push(SequenceDatasetItem::new(seq_length, noise_level)); + } + let dataset = InMemDataset::new(items); + + Self { dataset } + } +} + +impl Dataset for SequenceDataset { + fn get(&self, index: usize) -> Option { + self.dataset.get(index) + } + + fn len(&self) -> usize { + self.dataset.len() + } +} + +#[derive(Clone, Debug)] +pub struct SequenceBatcher { + device: B::Device, +} + +#[derive(Clone, Debug)] +pub struct SequenceBatch { + pub sequences: Tensor, // [batch_size, seq_length, input_size] + pub targets: Tensor, // [batch_size, 1] +} + +impl SequenceBatcher { + pub fn new(device: B::Device) -> Self { + Self { device } + } +} + +impl Batcher> for SequenceBatcher { + fn batch(&self, items: Vec) -> SequenceBatch { + let mut sequences: Vec> = Vec::new(); + + for item in items.iter() { + let seq_tensor = Tensor::::from_floats(item.sequence.as_slice(), &self.device); + // 
Add feature dimension, the input_size is 1 implicitly. We can change the input_size here with some operations + sequences.push(seq_tensor.unsqueeze_dims(&[-1])); + } + let sequences = Tensor::stack(sequences, 0); + + let targets = items + .iter() + .map(|item| Tensor::::from_floats([item.target], &self.device)) + .collect(); + let targets = Tensor::stack(targets, 0); + + SequenceBatch { sequences, targets } + } +} diff --git a/examples/modern-lstm/src/inference.rs b/examples/modern-lstm/src/inference.rs new file mode 100644 index 0000000000..bad0af2996 --- /dev/null +++ b/examples/modern-lstm/src/inference.rs @@ -0,0 +1,45 @@ +use crate::{ + dataset::{ + SequenceBatcher, SequenceDataset, SequenceDatasetItem, NOISE_LEVEL, NUM_SEQUENCES, + SEQ_LENGTH, + }, + model::LstmNetwork, + training::TrainingConfig, +}; +use burn::{ + data::{dataloader::batcher::Batcher, dataset::Dataset}, + prelude::*, + record::{CompactRecorder, Recorder}, +}; +use polars::prelude::*; + +pub fn infer(artifact_dir: &str, device: B::Device) { + // Loading model + let config = TrainingConfig::load(format!("{artifact_dir}/config.json")) + .expect("Config should exist for the model; run train first"); + let record = CompactRecorder::new() + .load(format!("{artifact_dir}/model").into(), &device) + .expect("Trained model should exist; run train first"); + + let model: LstmNetwork = config.model.init(&device).load_record(record); + + let dataset = SequenceDataset::new(NUM_SEQUENCES / 5, SEQ_LENGTH, NOISE_LEVEL); + let items: Vec = dataset.iter().collect(); + + let batcher = SequenceBatcher::new(device); + // Put all items in one batch + let batch = batcher.batch(items); + let predicted = model.forward(batch.sequences, None); + let targets = batch.targets; + + let predicted = predicted.squeeze::<1>(1).into_data(); + let expected = targets.squeeze::<1>(1).into_data(); + + // Display the predicted vs expected values + let results = df![ + "predicted" => &predicted.to_vec::().unwrap(), + "expected" => 
&expected.to_vec::().unwrap(), + ] + .unwrap(); + println!("{}", &results.head(Some(10))); +} diff --git a/examples/modern-lstm/src/lib.rs b/examples/modern-lstm/src/lib.rs new file mode 100644 index 0000000000..1a167ffd75 --- /dev/null +++ b/examples/modern-lstm/src/lib.rs @@ -0,0 +1,4 @@ +pub mod dataset; +pub mod inference; +pub mod model; +pub mod training; diff --git a/examples/modern-lstm/src/model.rs b/examples/modern-lstm/src/model.rs new file mode 100644 index 0000000000..268de59a0b --- /dev/null +++ b/examples/modern-lstm/src/model.rs @@ -0,0 +1,362 @@ +use burn::{ + nn::{ + Dropout, DropoutConfig, Initializer, LayerNorm, LayerNormConfig, Linear, LinearConfig, + LstmState, Sigmoid, Tanh, + }, + prelude::*, +}; + +/// LSTM Cell implementation with layer normalization. +/// +/// Mathematical formulation of LSTM: +/// f_t = σ(W_f · [h_{t-1}, x_t] + b_f) # Forget gate +/// i_t = σ(W_i · [h_{t-1}, x_t] + b_i) # Input gate +/// g_t = tanh(W_g · [h_{t-1}, x_t] + b_g) # Candidate cell state +/// o_t = σ(W_o · [h_{t-1}, x_t] + b_o) # Output gate +/// +/// c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t # New cell state +/// h_t = o_t ⊙ tanh(c_t) # New hidden state +/// +/// where: +/// - σ is the sigmoid function +/// - ⊙ is the element-wise multiplication +/// - [h_{t-1}, x_t] represents concatenation + +#[derive(Module, Debug)] +pub struct LstmCell { + pub hidden_size: usize, + // Combined weight matrices for efficiency + // weight_ih layer uses combined weights for [i_t, f_t, g_t, o_t] for input x_t + // weight_hh layer uses combined weights for [i_t, f_t, g_t, o_t] for hidden state h_{t-1} + pub weight_ih: Linear, + pub weight_hh: Linear, + // Layer Normalization for better training stability. Don't use BatchNorm because the input distribution is always changing for LSTM. 
+ pub norm_x: LayerNorm, // Normalize the input projection of the gate pre-activations + pub norm_h: LayerNorm, // Normalize hidden state + pub norm_c: LayerNorm, // Normalize cell state + pub dropout: Dropout, +} + +/// Configuration to create an `LstmCell` module using the init function. +#[derive(Config, Debug)] +pub struct LstmCellConfig { + // The size of the input features + pub input_size: usize, + // The size of the hidden state + pub hidden_size: usize, + // The dropout probability applied to the new hidden state + pub dropout: f64, +} + +impl LstmCellConfig { + // Initialize parameters using best practices: + // 1. Orthogonal initialization for better gradient flow (here we use Xavier because of the lack of Orthogonal in burn) + // 2. Initialize forget gate bias to 1.0 to prevent forgetting at start of training + #[allow(clippy::single_range_in_vec_init)] + pub fn init(&self, device: &B::Device) -> LstmCell { + let initializer = Initializer::XavierNormal { gain: 1.0 }; + let init_bias = Tensor::::ones([self.hidden_size], device); + + let mut weight_ih = LinearConfig::new(self.input_size, 4 * self.hidden_size) + .with_initializer(initializer.clone()) + .init(device); + // Set forget gate bias to 1.0 (helps with learning long sequences) + let bias = weight_ih + .bias + .clone() + .unwrap() + .val() + .slice_assign([self.hidden_size..2 * self.hidden_size], init_bias.clone()); + weight_ih.bias = weight_ih.bias.map(|p| p.map(|_t| bias)); + + let mut weight_hh = LinearConfig::new(self.hidden_size, 4 * self.hidden_size) + .with_initializer(initializer) + .init(device); + let bias = weight_hh + .bias + .clone() + .unwrap() + .val() + .slice_assign([self.hidden_size..2 * self.hidden_size], init_bias); + weight_hh.bias = weight_hh.bias.map(|p| p.map(|_t| bias)); + + LstmCell { + hidden_size: self.hidden_size, + weight_ih, + weight_hh, + norm_x: LayerNormConfig::new(4 * self.hidden_size).init(device), + norm_h: LayerNormConfig::new(self.hidden_size).init(device), + norm_c: LayerNormConfig::new(self.hidden_size).init(device), + 
dropout: DropoutConfig::new(self.dropout).init(), + } + } +} + +impl LstmCell { + /// Forward pass of LSTM cell. + /// Args: + /// x: Input tensor of shape (batch_size, input_size) + /// state: Previous state holding h_{t-1} (`hidden`) and c_{t-1} (`cell`), each of shape (batch_size, hidden_size) + /// Returns: + /// New state whose `cell` is c_t and whose `hidden` is h_t + pub fn forward(&self, x: Tensor, state: LstmState) -> LstmState { + let (h_prev, c_prev) = (state.hidden, state.cell); + + // Combined matrix multiplication for all gates + // Shape: (batch_size, 4 * hidden_size) + let gates_x = self.weight_ih.forward(x); // Transform input + let gates_h = self.weight_hh.forward(h_prev); // Transform previous hidden state + + // Apply layer normalization to the input projection only + let gates_x = self.norm_x.forward(gates_x); + // Combined gate pre-activations + let gates = gates_x + gates_h; + + // Split into individual gates in the order [i, f, g, o] + // Each gate shape: (batch_size, hidden_size) + let gates = gates.chunk(4, 1); + let i_gate = gates[0].clone(); + let f_gate = gates[1].clone(); + let g_gate = gates[2].clone(); + let o_gate = gates[3].clone(); + + // Apply gate non-linearities + let i_t = Sigmoid::new().forward(i_gate); + let f_t = Sigmoid::new().forward(f_gate); + let g_t = Tanh::new().forward(g_gate); + let o_t = Sigmoid::new().forward(o_gate); + + // Update cell state: c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t + let c_t = f_t * c_prev + i_t * g_t; + let c_t = self.norm_c.forward(c_t); + + // Update hidden state: h_t = o_t ⊙ tanh(c_t) + let h_t = o_t * Tanh::new().forward(c_t.clone()); + let h_t = self.norm_h.forward(h_t); + + let h_t = self.dropout.forward(h_t); + + // Argument order is (cell, hidden), matching init_state below + LstmState::new(c_t, h_t) + } + + // Create zero-initialized cell and hidden states of shape (batch_size, hidden_size) + pub fn init_state(&self, batch_size: usize, device: &B::Device) -> LstmState { + let cell = Tensor::zeros([batch_size, self.hidden_size], device); + let hidden = Tensor::zeros([batch_size, self.hidden_size], device); + + LstmState::new(cell, hidden) + } +} + +/// Stacked
LSTM implementation supporting multiple layers +/// Each layer processes the output of the previous layer +#[derive(Module, Debug)] +pub struct StackedLstm { + pub layers: Vec>, +} + +#[derive(Config, Debug)] +pub struct StackedLstmConfig { + pub input_size: usize, + pub hidden_size: usize, + pub num_layers: usize, + pub dropout: f64, +} + +impl StackedLstmConfig { + pub fn init(&self, device: &B::Device) -> StackedLstm { + let mut layers: Vec> = vec![]; + // Create list of LSTM cells, one for each layer + for i in 0..self.num_layers { + if i == 0 { + if i < self.num_layers - 1 { + layers.push( + LstmCellConfig::new(self.input_size, self.hidden_size, self.dropout) + .init(device), + ); + } else { + // No dropout on last layer + layers.push( + LstmCellConfig::new(self.input_size, self.hidden_size, 0.0).init(device), + ); + } + } else if i < self.num_layers - 1 { + layers.push( + LstmCellConfig::new(self.hidden_size, self.hidden_size, self.dropout) + .init(device), + ); + } else { + // No dropout on last layer + layers.push( + LstmCellConfig::new(self.hidden_size, self.hidden_size, 0.0).init(device), + ); + } + } + StackedLstm { layers } + } +} + +impl StackedLstm { + /// Process input sequence through stacked LSTM layers. 
+ /// + /// Args: + /// x: Input tensor of shape (batch_size, seq_length, input_size) + /// states: Optional initial states for each layer + /// + /// Returns: + /// Tuple of (output, states) where output has shape (batch_size, seq_length, hidden_size) + /// and states is a vector of length num_layers, both cell and hidden state in each element have shape (batch_size, hidden_size) + pub fn forward( + &self, + x: Tensor, + states: Option>>, + ) -> (Tensor, Vec>) { + let [batch_size, seq_length, _] = x.dims(); + let device = x.device(); + + let mut states = match states { + None => { + let mut temp: Vec> = vec![]; + for layer in self.layers.iter() { + temp.push(layer.init_state(batch_size, &device)); + } + temp + } + _ => states.unwrap(), + }; + + let mut layer_outputs = vec![]; + for t in 0..seq_length { + let mut input_t = x + .clone() + .slice([None, Some((t as i64, t as i64 + 1)), None]) + .squeeze::<2>(1); + for (i, lstm_cell) in self.layers.iter().enumerate() { + let mut state: LstmState = + LstmState::new(states[i].cell.clone(), states[i].hidden.clone()); + state = lstm_cell.forward(input_t, state); + input_t = state.hidden.clone(); + states[i] = state; + } + layer_outputs.push(input_t); + } + + // Stack output along sequence dimension + let output = Tensor::stack(layer_outputs, 1); + + (output, states) + } +} + +/// Complete LSTM network with bidirectional support. 
+/// +/// In bidirectional mode: +/// - Forward LSTM processes sequence from left to right +/// - Backward LSTM processes sequence from right to left +/// - Outputs are concatenated for final prediction +#[derive(Module, Debug)] +pub struct LstmNetwork { + // Forward direction LSTM + pub stacked_lstm: StackedLstm, + // Optional backward direction LSTM for bidirectional processing + pub reverse_lstm: Option>, + pub dropout: Dropout, + pub fc: Linear, +} + +#[derive(Config, Debug)] +pub struct LstmNetworkConfig { + #[config(default = 1)] + pub input_size: usize, // Single feature (number sequence) + #[config(default = 32)] + pub hidden_size: usize, // Size of LSTM hidden state + #[config(default = 2)] + pub num_layers: usize, // Number of LSTM layers + #[config(default = 1)] + pub output_size: usize, // Predict one number + #[config(default = 0.1)] + pub dropout: f64, + #[config(default = true)] + pub bidirectional: bool, // Use bidirectional LSTM +} + +impl LstmNetworkConfig { + pub fn init(&self, device: &B::Device) -> LstmNetwork { + // Forward direction LSTM + let stacked_lstm = StackedLstmConfig::new( + self.input_size, + self.hidden_size, + self.num_layers, + self.dropout, + ) + .init(device); + + // Optional backward direction LSTM for bidirectional processing + let (reverse_lstm, hidden_size) = if self.bidirectional { + let lstm = StackedLstmConfig::new( + self.input_size, + self.hidden_size, + self.num_layers, + self.dropout, + ) + .init(device); + (Some(lstm), 2 * self.hidden_size) + } else { + (None, self.hidden_size) + }; + + let fc = LinearConfig::new(hidden_size, self.output_size).init(device); + let dropout = DropoutConfig::new(self.dropout).init(); + + LstmNetwork { + stacked_lstm, + reverse_lstm, + dropout, + fc, + } + } +} + +impl LstmNetwork { + /// Forward pass of the network. + /// + /// For bidirectional processing: + /// 1. Process sequence normally with forward LSTM + /// 2. Process reversed sequence with backward LSTM + /// 3. 
Concatenate both outputs + /// 4. Apply final linear transformation + /// + /// Args: + /// x: Input tensor of shape (batch_size, seq_length, input_size) + /// states: Optional initial states + /// + /// Returns: + /// Output tensor of shape (batch_size, output_size) + pub fn forward(&self, x: Tensor, states: Option>>) -> Tensor { + let seq_length = x.dims()[1] as i64; + // Forward direction + let (mut output, _states) = self.stacked_lstm.forward(x.clone(), states); + + output = match &self.reverse_lstm { + Some(reverse_lstm) => { + //Process sequence in reverse direction + let (mut reverse_output, _states) = reverse_lstm.forward(x.flip([1]), None); + // Flip back to align with forward sequence + reverse_output = reverse_output.flip([1]); + // Concatenate forward and backward outputs along the feature dimension + output = Tensor::cat(vec![output, reverse_output], 2); + output + } + None => output, + }; + + // Apply dropout before final layer + output = self.dropout.forward(output); + // Use final timestep output for prediction + self.fc.forward( + output + .slice([None, Some((seq_length - 1, seq_length)), None]) + .squeeze::<2>(1), + ) + } +} diff --git a/examples/modern-lstm/src/training.rs b/examples/modern-lstm/src/training.rs new file mode 100644 index 0000000000..9f6af81328 --- /dev/null +++ b/examples/modern-lstm/src/training.rs @@ -0,0 +1,131 @@ +use crate::dataset::{ + SequenceBatcher, SequenceDataset, NOISE_LEVEL, NUM_SEQUENCES, RANDOM_SEED, SEQ_LENGTH, +}; +use crate::model::{LstmNetwork, LstmNetworkConfig}; +use burn::{ + data::dataloader::DataLoaderBuilder, + module::AutodiffModule, + nn::loss::{MseLoss, Reduction::Mean}, + optim::{AdamConfig, GradientsParams, Optimizer}, + prelude::*, + record::CompactRecorder, + tensor::backend::AutodiffBackend, +}; + +#[derive(Config)] +pub struct TrainingConfig { + pub model: LstmNetworkConfig, + pub optimizer: AdamConfig, + + #[config(default = 30)] + pub num_epochs: usize, + #[config(default = 32)] + pub 
batch_size: usize, + #[config(default = 2)] + pub num_workers: usize, + #[config(default = 1e-3)] + pub lr: f64, +} + +// Create a fresh directory to save the model and model config +fn create_artifact_dir(artifact_dir: &str) { + // Remove existing artifacts; the error is ignored because the directory may not exist yet + std::fs::remove_dir_all(artifact_dir).ok(); + std::fs::create_dir_all(artifact_dir).expect("Artifact directory should be created successfully"); +} + +pub fn train(artifact_dir: &str, config: TrainingConfig, device: B::Device) { + create_artifact_dir(artifact_dir); + + // Save training config + config + .save(format!("{artifact_dir}/config.json")) + .expect("Config should be saved successfully"); + B::seed(RANDOM_SEED); + + // Create the model and optimizer + let mut model = config.model.init::(&device); + let mut optim = config.optimizer.init::>(); + + // Create the batcher + let batcher_train = SequenceBatcher::::new(device.clone()); + let batcher_valid = SequenceBatcher::::new(device.clone()); + + // Create the dataloaders + let dataloader_train = DataLoaderBuilder::new(batcher_train) + .batch_size(config.batch_size) + .shuffle(RANDOM_SEED) + .num_workers(config.num_workers) + .build(SequenceDataset::new(NUM_SEQUENCES, SEQ_LENGTH, NOISE_LEVEL)); + + let dataloader_valid = DataLoaderBuilder::new(batcher_valid) + .batch_size(config.batch_size) + .shuffle(RANDOM_SEED) + .num_workers(config.num_workers) + // 20% size of training + .build(SequenceDataset::new( + NUM_SEQUENCES / 5, + SEQ_LENGTH, + NOISE_LEVEL, + )); + + let train_num_items = dataloader_train.num_items(); + let valid_num_items = dataloader_valid.num_items(); + + println!("Starting training..."); + // Iterate over our training for `num_epochs` epochs (1-based epoch counter) + for epoch in 1..=config.num_epochs { + // Initialize the training and validation metrics at the start of each epoch + let mut train_losses = vec![]; + let mut train_loss = 0.0; + let mut valid_losses = vec![]; + let mut valid_loss = 0.0; + + // Implement our training loop + for batch in dataloader_train.iter() { + let output = model.forward(batch.sequences, None); 
+ let loss = MseLoss::new().forward(output, batch.targets.clone(), Mean); + train_loss += loss.clone().into_scalar().elem::() * batch.targets.dims()[0] as f32; + + // Gradients for the current backward pass + let grads = loss.backward(); + // Gradients linked to each parameter of the model + let grads = GradientsParams::from_grads(grads, &model); + // Update the model using the optimizer + model = optim.step(config.lr, model, grads); + } + + // The averaged train loss per epoch (batch means weighted by batch size) + let avg_train_loss = train_loss / train_num_items as f32; + train_losses.push(avg_train_loss); + + // Get the model without autodiff + let valid_model = model.valid(); + + // Implement our validation loop + for batch in dataloader_valid.iter() { + let output = valid_model.forward(batch.sequences, None); + let loss = MseLoss::new().forward(output, batch.targets.clone(), Mean); + valid_loss += loss.into_scalar().elem::() * batch.targets.dims()[0] as f32; + } + // The averaged validation loss per epoch + let avg_valid_loss = valid_loss / valid_num_items as f32; + valid_losses.push(avg_valid_loss); + + // Display the averaged training and validation metrics every 5 epochs + if epoch % 5 == 0 { + println!( + "Epoch {}/{}, Avg Loss {:.4}, Avg Val Loss: {:.4}", + epoch, + config.num_epochs, + avg_train_loss, + avg_valid_loss, + ); + } + } + + // Save the trained model + model + .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new()) + .expect("Trained model should be saved successfully"); +}